@@ -8,10 +8,13 @@
 
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
@@ -68,6 +71,145 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   }
 };
 
+struct PackOpTiling
+    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    // Note that here we only consider untiled dimensions and outer tiled data
+    // dimensions; the inner tiled data dimensions are materialized when
+    // building the body of the operation.
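+    // E.g., packing a 2-D source into a 4-D packed result yields two parallel
+    // loops here, one per source dimension.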
+    auto packOp = cast<PackOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes(
+        packOp.getSourceRank(), utils::IteratorType::parallel);
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
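+    // The domain covers the outer (untiled and outer-tiled) dimensions; its
+    // upper bounds are the leading `getSourceRank()` dims of the result shape.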
+    OpBuilder::InsertionGuard guard(b);
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+    int64_t rank = packOp.getSourceRank();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    ReifiedRankedShapedTypeDims resultShape;
+    (void)packOp.reifyResultShapes(b, resultShape);
+    SmallVector<Range> loopRanges(rank);
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      loopRanges[dim].offset = zero;
+      loopRanges[dim].stride = one;
+      loopRanges[dim].size = resultShape[0][dim];
+    }
+    return loopRanges;
+  }
| 104 | + |
| 105 | + SmallVector<Operation *> |
| 106 | + getTiledImplementation(Operation *op, OpBuilder &b, |
| 107 | + ArrayRef<OpFoldResult> offsets, |
| 108 | + ArrayRef<OpFoldResult> sizes) const { |
| 109 | + auto packOp = cast<PackOp>(op); |
| 110 | + Location loc = packOp.getLoc(); |
| 111 | + |
| 112 | + // The tiling is applied on interchanged dimensions. We have to undo the |
| 113 | + // interchange to map sizes and offsets to the original input. |
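+    // E.g., with outer_dims_perm = [1, 0], offsets/sizes arrive swapped and
+    // the inverse permutation [1, 0] restores them to source dimension order.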
+    int64_t inputRank = packOp.getSourceRank();
+    ArrayRef<int64_t> dimsToOuterBlock(packOp.getOuterDimsPerm());
+    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
+    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
+    if (!dimsToOuterBlock.empty()) {
+      SmallVector<int64_t> inversedPerm =
+          invertPermutationVector(dimsToOuterBlock);
+      applyPermutationToVector<OpFoldResult>(origOffsets, inversedPerm);
+      applyPermutationToVector<OpFoldResult>(origSizes, inversedPerm);
+    }
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    SmallVector<OpFoldResult> srcDimValues =
+        tensor::createDimValues(b, loc, packOp.getSource());
+    SmallVector<OpFoldResult> inputIndices, inputSizes;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      using AV = AffineValueExpr;
+      AffineBuilder ab(b, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(b.getContext(), dim0, dim1);
+      bindSymbols(b.getContext(), sym);
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i and
+        // tile_i.
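+        // E.g., with tile size 32, loop offset 2 reads the source at index
+        // 2 * 32 = 64, and a loop size of 3 covers 3 * 32 = 96 elements.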
+        auto avOffset = AV(dim0).bind(origOffsets[dim]);
+        auto avSize = AV(dim0).bind(origSizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        inputIndices.push_back(ab.mul(avOffset, avTileSize));
+        inputSizes.push_back(ab.mul(avSize, avTileSize));
+      } else {
+        inputIndices.push_back(origOffsets[dim]);
+        inputSizes.push_back(origSizes[dim]);
+      }
+
+      // Limit the size of the input operand for incomplete tiles.
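+      // E.g., a read of 96 elements starting at index 64 of a dimension of
+      // size 100 is clamped to min(96, 100 - 64) = 36 elements.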
+      OpFoldResult dimSize = srcDimValues[dim];
+      auto avDimSize = AV(dim0).bind(dimSize);
+      auto avInputIdx = AV(dim1).bind(inputIndices.back());
+      inputSizes.back() =
+          ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
+    }
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
+                                     outputSizes)))
+      return {};
+
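+    // Extend the unit strides to the destination rank and take the matching
+    // tile of the destination as the init operand of the tiled op.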
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
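+    // Build the tiled pack op on the sliced source/dest, reusing the original
+    // op's attributes (outer_dims_perm, inner_dims_pos, static inner tiles).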
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return {tiledPackOp};
+  }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    // The iteration domain is over the outer dimensions of the packed layout.
+    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
+    // The inner dimensions of `resultOffsets` are zeros because tiling is not
+    // applied to them.
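+    // E.g., for a 2-D source packed into a 4-D result, offsets [o0, o1] give
+    // resultOffsets = [o0, o1, 0, 0] and the inner sizes stay the full tiles.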
+    auto packOp = cast<PackOp>(op);
+    int64_t inputRank = packOp.getSourceRank();
+    int64_t outputRank = packOp.getDestRank();
+    auto zeroAttr = b.getI64IntegerAttr(0);
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultOffsets.append(outputRank - inputRank, zeroAttr);
+
+    ReifiedRankedShapedTypeDims outputShape;
+    (void)packOp.reifyResultShapes(b, outputShape);
+    resultSizes.assign(sizes.begin(), sizes.end());
+    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
+      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+
+    return success();
+  }
+};
+
 } // namespace
 
 Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
@@ -282,5 +424,6 @@ void mlir::tensor::registerTilingInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
   });
 }