Skip to content

Commit 0d03ba6

Browse files
committed
[mlir][tensor] Implement TilingInterface for tensor.pack op.
We can compute the offsets and sizes for the slice of the input because the iteration domain is defined over the outer loops. If a dimension is tiled, the i-th index is the product of offset_i and inner_tile_i. Unlike tiling a pad op, we do not have to deal with reading zero data from the input, because the tiling sizes apply to the packed outer dimensions: for each packed tile we read either the entire tile or a partial tile. The scf.if and tensor.generate ops are not needed in this context. Co-authored-by: Lorenzo Chelini <[email protected]> Reviewed By: rengolin, mravishankar Differential Revision: https://reviews.llvm.org/D138631
1 parent 7ca32bd commit 0d03ba6

File tree

5 files changed

+413
-0
lines changed

5 files changed

+413
-0
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_AFFINE_UTILS_H
1515

1616
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1718

1819
namespace mlir {
1920

@@ -328,6 +329,56 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
328329
/// that would change the read within `memOp`.
329330
template <typename EffectType, typename T>
330331
bool hasNoInterveningEffect(Operation *start, T memOp);
332+
333+
struct AffineValueExpr {
334+
explicit AffineValueExpr(AffineExpr e) : e(e) {}
335+
AffineValueExpr bind(Value v) {
336+
this->v = v;
337+
return *this;
338+
}
339+
AffineValueExpr bind(OpFoldResult v) {
340+
this->v = v;
341+
return *this;
342+
}
343+
operator AffineExpr() const { return e; }
344+
operator OpFoldResult() const { return v; }
345+
AffineExpr e;
346+
OpFoldResult v;
347+
};
348+
349+
/// Helper struct to build simple AffineValueExprs with minimal type inference
350+
/// support.
351+
struct AffineBuilder {
352+
AffineBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
353+
OpFoldResult add(AffineValueExpr lhs, AffineValueExpr rhs) {
354+
return makeComposedFoldedAffineApply(b, loc, {lhs.e + rhs.e}, {lhs, rhs});
355+
}
356+
OpFoldResult sub(AffineValueExpr lhs, AffineValueExpr rhs) {
357+
return makeComposedFoldedAffineApply(b, loc, {lhs.e - rhs.e}, {lhs, rhs});
358+
}
359+
OpFoldResult mul(AffineValueExpr lhs, AffineValueExpr rhs) {
360+
return makeComposedFoldedAffineApply(b, loc, {lhs.e * rhs.e}, {lhs, rhs});
361+
}
362+
OpFoldResult ceil(AffineValueExpr lhs, AffineValueExpr rhs) {
363+
return makeComposedFoldedAffineApply(b, loc, {lhs.e.ceilDiv(rhs.e)},
364+
{lhs, rhs});
365+
}
366+
OpFoldResult min(ArrayRef<OpFoldResult> vals) {
367+
return makeComposedFoldedAffineMin(
368+
b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()),
369+
vals);
370+
}
371+
OpFoldResult max(ArrayRef<OpFoldResult> vals) {
372+
return makeComposedFoldedAffineMax(
373+
b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()),
374+
vals);
375+
}
376+
377+
private:
378+
OpBuilder &b;
379+
Location loc;
380+
};
381+
331382
} // namespace mlir
332383

333384
#endif // MLIR_DIALECT_AFFINE_UTILS_H

mlir/lib/Dialect/Tensor/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ add_mlir_dialect_library(MLIRTensorTilingInterfaceImpl
5757

5858
LINK_LIBS PUBLIC
5959
MLIRAffineDialect
60+
MLIRDialectUtils
6061
MLIRIR
6162
MLIRLinalgDialect
6263
MLIRSCFDialect
6364
MLIRSupport
6465
MLIRTensorDialect
66+
MLIRTensorUtils
6567
MLIRTilingInterface
6668
)

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88

99
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Affine/Utils.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
1213
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/SCF/IR/SCF.h"
1415
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/Tensor/Utils/Utils.h"
17+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1518
#include "mlir/Interfaces/TilingInterface.h"
1619

1720
using namespace mlir;
@@ -68,6 +71,145 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
6871
}
6972
};
7073

74+
struct PackOpTiling
75+
: public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
76+
77+
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
78+
// Note that here we only consider untiled dimensions and outer tiled data
79+
// dimensions, the inner tiled data dimensions are materialized when
80+
// building the body of the operation.
81+
auto packOp = cast<PackOp>(op);
82+
SmallVector<utils::IteratorType> iteratorTypes(
83+
packOp.getSourceRank(), utils::IteratorType::parallel);
84+
return iteratorTypes;
85+
}
86+
87+
SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
88+
OpBuilder::InsertionGuard guard(b);
89+
auto packOp = cast<PackOp>(op);
90+
Location loc = packOp.getLoc();
91+
int64_t rank = packOp.getSourceRank();
92+
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
93+
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
94+
ReifiedRankedShapedTypeDims resultShape;
95+
(void)packOp.reifyResultShapes(b, resultShape);
96+
SmallVector<Range> loopRanges(rank);
97+
for (auto dim : llvm::seq<int64_t>(0, rank)) {
98+
loopRanges[dim].offset = zero;
99+
loopRanges[dim].stride = one;
100+
loopRanges[dim].size = resultShape[0][dim];
101+
}
102+
return loopRanges;
103+
}
104+
105+
SmallVector<Operation *>
106+
getTiledImplementation(Operation *op, OpBuilder &b,
107+
ArrayRef<OpFoldResult> offsets,
108+
ArrayRef<OpFoldResult> sizes) const {
109+
auto packOp = cast<PackOp>(op);
110+
Location loc = packOp.getLoc();
111+
112+
// The tiling is applied on interchanged dimensions. We have to undo the
113+
// interchange to map sizes and offsets to the original input.
114+
int64_t inputRank = packOp.getSourceRank();
115+
ArrayRef<int64_t> dimsToOuterBlock(packOp.getOuterDimsPerm());
116+
SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
117+
SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
118+
if (!dimsToOuterBlock.empty()) {
119+
SmallVector<int64_t> inversedPerm =
120+
invertPermutationVector(dimsToOuterBlock);
121+
applyPermutationToVector<OpFoldResult>(origOffsets, inversedPerm);
122+
applyPermutationToVector<OpFoldResult>(origSizes, inversedPerm);
123+
}
124+
125+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
126+
packOp.getDimAndTileMapping();
127+
SmallVector<OpFoldResult> srcDimValues =
128+
tensor::createDimValues(b, loc, packOp.getSource());
129+
SmallVector<OpFoldResult> inputIndices, inputSizes;
130+
for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
131+
using AV = AffineValueExpr;
132+
AffineBuilder ab(b, loc);
133+
AffineExpr dim0, dim1, sym;
134+
bindDims(b.getContext(), dim0, dim1);
135+
bindSymbols(b.getContext(), sym);
136+
if (dimAndTileMapping.count(dim)) {
137+
// If the data dimension is tiled, the i-th index is the product of
138+
// offset_i and tile_i, and the i-th size is the product of sizes_i and
139+
// tile_i.
140+
auto avOffset = AV(dim0).bind(origOffsets[dim]);
141+
auto avSize = AV(dim0).bind(origSizes[dim]);
142+
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
143+
inputIndices.push_back(ab.mul(avOffset, avTileSize));
144+
inputSizes.push_back(ab.mul(avSize, avTileSize));
145+
} else {
146+
inputIndices.push_back(origOffsets[dim]);
147+
inputSizes.push_back(origSizes[dim]);
148+
}
149+
150+
// Limit the size of the input operand for incomplete tiles.
151+
OpFoldResult dimSize = srcDimValues[dim];
152+
auto avDimSize = AV(dim0).bind(dimSize);
153+
auto avInputIdx = AV(dim1).bind(inputIndices.back());
154+
inputSizes.back() =
155+
ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
156+
}
157+
158+
auto oneAttr = b.getI64IntegerAttr(1);
159+
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
160+
161+
SmallVector<Value> tiledOperands;
162+
tiledOperands.push_back(b.create<ExtractSliceOp>(
163+
loc, packOp.getSource(), inputIndices, inputSizes, strides));
164+
165+
SmallVector<OpFoldResult> outputOffsets, outputSizes;
166+
if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
167+
outputSizes)))
168+
return {};
169+
170+
strides.append(packOp.getDestRank() - inputRank, oneAttr);
171+
auto extractSlice = b.create<ExtractSliceOp>(
172+
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
173+
tiledOperands.push_back(extractSlice);
174+
175+
if (auto val = packOp.getPaddingValue())
176+
tiledOperands.push_back(val);
177+
for (auto tile : packOp.getInnerTiles())
178+
tiledOperands.push_back(tile);
179+
180+
Operation *tiledPackOp = b.create<PackOp>(
181+
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
182+
183+
return {tiledPackOp};
184+
}
185+
186+
LogicalResult
187+
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
188+
ArrayRef<OpFoldResult> offsets,
189+
ArrayRef<OpFoldResult> sizes,
190+
SmallVector<OpFoldResult> &resultOffsets,
191+
SmallVector<OpFoldResult> &resultSizes) const {
192+
// The iteration domain is over outer dimensions of packed layout. In this
193+
// context, the outer dimensions of `resultOffsets` are `offsets`. The
194+
// inner dimensions of `resultOffsets` are zeros because tiling is not
195+
// applied to them.
196+
auto packOp = cast<PackOp>(op);
197+
int64_t inputRank = packOp.getSourceRank();
198+
int64_t outputRank = packOp.getDestRank();
199+
auto zeroAttr = b.getI64IntegerAttr(0);
200+
resultOffsets.assign(offsets.begin(), offsets.end());
201+
resultOffsets.append(outputRank - inputRank, zeroAttr);
202+
203+
ReifiedRankedShapedTypeDims outputShape;
204+
(void)packOp.reifyResultShapes(b, outputShape);
205+
resultSizes.assign(sizes.begin(), sizes.end());
206+
for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
207+
resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
208+
209+
return success();
210+
}
211+
};
212+
71213
} // namespace
72214

73215
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
@@ -282,5 +424,6 @@ void mlir::tensor::registerTilingInterfaceExternalModels(
282424
DialectRegistry &registry) {
283425
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
284426
tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
427+
tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
285428
});
286429
}

0 commit comments

Comments
 (0)