|
| 1 | +//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" |
| 10 | + |
| 11 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 12 | +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 13 | +#include "mlir/Dialect/ArmSME/Utils/Utils.h" |
| 14 | +#include "mlir/Pass/Pass.h" |
| 15 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 16 | + |
| 17 | +namespace mlir { |
| 18 | +#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS |
| 19 | +#include "mlir/Conversion/Passes.h.inc" |
| 20 | +} // namespace mlir |
| 21 | + |
| 22 | +#define DEBUG_TYPE "arith-to-arm-sme" |
| 23 | + |
| 24 | +using namespace mlir; |
| 25 | + |
| 26 | +//===----------------------------------------------------------------------===// |
| 27 | +// Conversion helpers |
| 28 | +//===----------------------------------------------------------------------===// |
| 29 | + |
| 30 | +/// Returns true if 'val' is a splat of zero, false otherwise. |
| 31 | +static bool isSplatZero(Type elemType, DenseElementsAttr val) { |
| 32 | + if (llvm::isa<FloatType>(elemType)) |
| 33 | + return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); |
| 34 | + if (llvm::isa<IntegerType>(elemType)) |
| 35 | + return val && val.isSplat() && val.getSplatValue<APInt>().isZero(); |
| 36 | + return false; |
| 37 | +} |
| 38 | + |
| 39 | +namespace { |
| 40 | + |
| 41 | +//===----------------------------------------------------------------------===// |
| 42 | +// ConstantOp |
| 43 | +//===----------------------------------------------------------------------===// |
| 44 | + |
| 45 | +/// Conversion pattern for dense arith.constant. |
| 46 | +struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { |
| 47 | + using OpRewritePattern<arith::ConstantOp>::OpRewritePattern; |
| 48 | + |
| 49 | + LogicalResult matchAndRewrite(arith::ConstantOp constantOp, |
| 50 | + PatternRewriter &rewriter) const final { |
| 51 | + auto tileType = dyn_cast<VectorType>(constantOp.getType()); |
| 52 | + if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
| 53 | + return failure(); |
| 54 | + |
| 55 | + auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); |
| 56 | + if (!denseAttr || !denseAttr.isSplat()) |
| 57 | + return failure(); |
| 58 | + |
| 59 | + auto tileElementType = tileType.getElementType(); |
| 60 | + |
| 61 | + // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. |
| 62 | + if (isSplatZero(tileElementType, denseAttr)) { |
| 63 | + rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType); |
| 64 | + return success(); |
| 65 | + } |
| 66 | + |
| 67 | + // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' |
| 68 | + // ops that broadcast the constant to each tile slice. |
| 69 | + auto loc = constantOp.getLoc(); |
| 70 | + |
| 71 | + // To fill a tile with a constant, we create a 1-D splat of the constant, |
| 72 | + // then move that into each tile slice (the largest unit we can set at once, |
| 73 | + // outside of operations like the outerproduct). |
| 74 | + VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
| 75 | + auto denseAttr1D = DenseElementsAttr::get( |
| 76 | + tileSliceType, denseAttr.getSplatValue<Attribute>()); |
| 77 | + auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); |
| 78 | + |
| 79 | + auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
| 80 | + auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
| 81 | + Value currentTile) { |
| 82 | + // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile |
| 83 | + // slice. |
| 84 | + auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>( |
| 85 | + loc, tileType, constantOp1D, currentTile, tileSliceIndex); |
| 86 | + return nextTile.getResult(); |
| 87 | + }; |
| 88 | + auto forOp = mlir::arm_sme::createLoopOverTileSlices( |
| 89 | + rewriter, loc, initTile, makeLoopBody); |
| 90 | + rewriter.replaceOp(constantOp, forOp.getResult(0)); |
| 91 | + |
| 92 | + return success(); |
| 93 | + } |
| 94 | +}; |
| 95 | + |
| 96 | +} // namespace |
| 97 | + |
| 98 | +//===----------------------------------------------------------------------===// |
| 99 | +// Pattern population |
| 100 | +//===----------------------------------------------------------------------===// |
| 101 | + |
| 102 | +void mlir::arith::populateArithToArmSMEConversionPatterns( |
| 103 | + RewritePatternSet &patterns) { |
| 104 | + patterns.add<ConstantOpToArmSMELowering>(patterns.getContext()); |
| 105 | +} |
| 106 | + |
| 107 | +//===----------------------------------------------------------------------===// |
| 108 | +// Pass definition |
| 109 | +//===----------------------------------------------------------------------===// |
| 110 | + |
| 111 | +namespace { |
| 112 | +struct ArithToArmSMEConversionPass final |
| 113 | + : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> { |
| 114 | + using impl::ArithToArmSMEConversionPassBase< |
| 115 | + ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; |
| 116 | + |
| 117 | + void runOnOperation() override { |
| 118 | + RewritePatternSet patterns(&getContext()); |
| 119 | + arith::populateArithToArmSMEConversionPatterns(patterns); |
| 120 | + if (failed( |
| 121 | + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
| 122 | + return signalPassFailure(); |
| 123 | + } |
| 124 | +}; |
| 125 | +} // namespace |
0 commit comments