16
16
#include " mlir/Dialect/Tensor/Utils/Utils.h"
17
17
#include " mlir/Dialect/Utils/IndexingUtils.h"
18
18
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
19
+ #include " llvm/Support/Debug.h"
19
20
20
21
namespace mlir {
21
22
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
@@ -29,10 +30,66 @@ using namespace mlir::linalg;
29
30
30
31
namespace {
31
32
33
+ // The struct contains the infomation about mapping packing information to
34
+ // the iteration domain of Linalg ops.
35
+ struct PackInfo {
36
+ int64_t getNumTiledLoops () const { return tileToPointMapping.size (); };
37
+ // InnerDimsPos on iteration domain, which follows the order in pack ops.
38
+ SmallVector<int64_t > tiledDimsPos;
39
+ // The sizes of tiling data dimensions on iteration domain.
40
+ llvm::DenseMap<int64_t , OpFoldResult> domainDimAndTileMapping;
41
+ // The mapping from a dimension of iteration domain to the corresponding inner
42
+ // tiling dimension on iteration domain.
43
+ llvm::DenseMap<int64_t , int64_t > tileToPointMapping;
44
+ // The permutation of outer dims (on domain).
45
+ SmallVector<int64_t > outerDimsOnDomainPerm;
46
+ Optional<Value> paddingValue;
47
+ };
48
+
49
+ static PackInfo getPackingInfoFromConsumer (
50
+ AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
51
+ ArrayRef<int64_t > innerDimsPos, ArrayRef<int64_t > outerDimsPerm,
52
+ Optional<Value> paddingValue = llvm::None) {
53
+ LLVM_DEBUG (
54
+ { llvm::dbgs () << " --- Construct PackInfo From A Consumer ---\n " ; });
55
+ PackInfo packInfo;
56
+ packInfo.paddingValue = paddingValue;
57
+ int64_t origNumDims = indexingMap.getNumDims ();
58
+ SmallVector<AffineExpr> exprs (indexingMap.getResults ());
59
+ for (auto [index, innerDimPos, tileSize] :
60
+ llvm::zip_equal (llvm::seq<unsigned >(0 , innerDimsPos.size ()),
61
+ innerDimsPos, innerTileSizes)) {
62
+ int64_t domainDimPos =
63
+ exprs[innerDimPos].cast <AffineDimExpr>().getPosition ();
64
+ packInfo.tiledDimsPos .push_back (domainDimPos);
65
+ packInfo.domainDimAndTileMapping [domainDimPos] = tileSize;
66
+ packInfo.tileToPointMapping [domainDimPos] = origNumDims + index;
67
+ LLVM_DEBUG ({
68
+ llvm::dbgs () << " map innerDimPos=" << innerDimPos
69
+ << " to iteration dimension (d" << domainDimPos << " , d"
70
+ << packInfo.tileToPointMapping [domainDimPos]
71
+ << " ), which has size=("
72
+ << packInfo.domainDimAndTileMapping [domainDimPos] << " )\n " ;
73
+ });
74
+ }
75
+
76
+ for (auto dim : outerDimsPerm)
77
+ packInfo.outerDimsOnDomainPerm .push_back (indexingMap.getDimPosition (dim));
78
+ if (!packInfo.outerDimsOnDomainPerm .empty ()) {
79
+ LLVM_DEBUG ({
80
+ llvm::dbgs () << " map outer dimsDimsPerm to " ;
81
+ for (auto dim : packInfo.outerDimsOnDomainPerm )
82
+ llvm::dbgs () << dim << " " ;
83
+ llvm::dbgs () << " \n " ;
84
+ });
85
+ }
86
+
87
+ return packInfo;
88
+ }
89
+
32
90
// / Returns a tuple for packed operand and indexing_map with the assumptions:
33
91
// / 1) The generic op is the producer of the pack op.
34
92
// / 2) The generic op has only one result.
35
- // / 3) The indexing map of the output operand is identity.
36
93
// / If the operand is a scalar or packing dimensions are all irrelevant to the
37
94
// / operand, the opreand and the updated indexing map will be returned.
38
95
// / Otherwise, it returns the packed operand and the updated indexing map. E.g.,
@@ -62,62 +119,57 @@ namespace {
62
119
// / inner_tiles = [8]
63
120
// / into %init : tensor<?xf32> -> tensor<?x8xf32>
64
121
static std::tuple<Value, AffineMap>
65
- getOrCreatePackedViewOfOperand (OpBuilder &b, Location loc,
66
- tensor::PackOp packOp, GenericOp genericOp,
67
- OpOperand *opOperand) {
68
- int numOrigLoops = genericOp.getNumLoops ();
69
- int64_t numInnerLoops = packOp.getInnerDimsPos ().size ();
122
+ getOrCreatePackedViewOfOperand (OpBuilder &b, Location loc, PackInfo packInfo,
123
+ GenericOp genericOp, OpOperand *opOperand) {
124
+ int64_t numOrigLoops = genericOp.getNumLoops ();
125
+ int64_t numInnerLoops = packInfo.getNumTiledLoops ();
70
126
int64_t numLoops = numOrigLoops + numInnerLoops;
71
127
AffineMap origIndexingMap = genericOp.getMatchingIndexingMap (opOperand);
128
+ llvm::DenseMap<int64_t , int64_t > domainDimToOperandDim;
72
129
SmallVector<AffineExpr> exprs (origIndexingMap.getResults ());
73
-
74
130
if (genericOp.isScalar (opOperand))
75
- return std::make_tuple (
76
- opOperand->get (),
77
- AffineMap::get (numLoops, 0 , exprs, packOp.getContext ()));
78
-
79
- llvm::SetVector<int64_t > innerDimsPosSet (packOp.getInnerDimsPos ().begin (),
80
- packOp.getInnerDimsPos ().end ());
81
- // Mapping from AffinDimExpr of indexing maps to the operand shape dimension.
82
- DenseMap<int64_t , int64_t > iterMapToDim;
83
- for (auto [index, expr] : llvm::enumerate (origIndexingMap.getResults ())) {
131
+ return std::make_tuple (opOperand->get (),
132
+ AffineMap::get (numLoops, 0 , exprs, b.getContext ()));
133
+
134
+ // Step 1. Construct the information of packing data dimensions; append inner
135
+ // dimensions to the indexing maps for the operand.
136
+ for (auto [index, expr] : llvm::enumerate (exprs)) {
84
137
int64_t dimPos = expr.cast <AffineDimExpr>().getPosition ();
85
- if (!innerDimsPosSet.contains (dimPos))
86
- continue ;
87
- iterMapToDim[dimPos] = index;
138
+ domainDimToOperandDim[dimPos] = index;
88
139
}
89
-
90
- // Construct the information of packing data dimensions and new indexing maps
91
- // for the operand.
92
140
SmallVector<int64_t > innerDimsPos;
93
141
SmallVector<OpFoldResult> innerTileSizes;
94
- for (auto [index, value] : llvm::enumerate (
95
- llvm::zip (packOp.getInnerDimsPos (), packOp.getMixedTiles ()))) {
96
- int64_t dimPos = std::get<0 >(value);
97
- if (!iterMapToDim.count (dimPos))
142
+ for (auto dimPos : packInfo.tiledDimsPos ) {
143
+ if (!domainDimToOperandDim.count (dimPos))
98
144
continue ;
99
- innerDimsPos.push_back (iterMapToDim[dimPos]);
100
- innerTileSizes.push_back (std::get<1 >(value));
101
- exprs.push_back (b.getAffineDimExpr (numOrigLoops + index));
145
+ int64_t index = domainDimToOperandDim[dimPos];
146
+ innerTileSizes.push_back (packInfo.domainDimAndTileMapping [dimPos]);
147
+ innerDimsPos.push_back (index);
148
+ exprs.push_back (b.getAffineDimExpr (packInfo.tileToPointMapping [dimPos]));
102
149
}
103
- auto indexingMap = AffineMap::get (numLoops, 0 , exprs, packOp.getContext ());
104
150
151
+ // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
152
+ // TODO: should we propagate the permutation of outer dims to the pack op?
105
153
SmallVector<int64_t > outerDimsPerm;
106
- for (auto outDim : packOp.getOuterDimsPerm ()) {
107
- if (!iterMapToDim.count (outDim))
108
- continue ;
109
- outerDimsPerm.push_back (iterMapToDim[outDim]);
154
+ if (!packInfo.outerDimsOnDomainPerm .empty ()) {
155
+ SmallVector<int64_t > inversedOuterPerm =
156
+ invertPermutationVector (packInfo.outerDimsOnDomainPerm );
157
+ for (auto i : llvm::seq<unsigned >(0 , origIndexingMap.getNumResults ())) {
158
+ int64_t dimPos = exprs[i].cast <AffineDimExpr>().getPosition ();
159
+ exprs[i] = b.getAffineDimExpr (inversedOuterPerm[dimPos]);
160
+ }
110
161
}
162
+ auto indexingMap = AffineMap::get (numLoops, 0 , exprs, b.getContext ());
111
163
112
164
// The operand does not have dimensions that relates to pack op.
113
- if (innerDimsPos.empty () && outerDimsPerm. empty () )
165
+ if (innerDimsPos.empty ())
114
166
return std::make_tuple (opOperand->get (), indexingMap);
115
167
116
168
auto empty = tensor::PackOp::createDestinationTensor (
117
169
b, loc, opOperand->get (), innerTileSizes, innerDimsPos, outerDimsPerm);
118
170
auto packedOperand = b.create <tensor::PackOp>(
119
171
loc, opOperand->get (), empty, innerDimsPos, innerTileSizes,
120
- packOp. getPaddingValue () , outerDimsPerm);
172
+ packInfo. paddingValue , outerDimsPerm);
121
173
return std::make_tuple (packedOperand, indexingMap);
122
174
}
123
175
@@ -187,34 +239,45 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
187
239
return failure ();
188
240
189
241
OpOperand *opOperand = genericOp.getDpsInitOperand (0 );
190
- // TODO: Add support for all permutation indexing maps.
191
- if (! genericOp.getMatchingIndexingMap (opOperand). isIdentity ())
192
- return rewriter. notifyMatchFailure (
193
- packOp, " the result of generic op does not have identity indexing_map " );
242
+ auto packInfo = getPackingInfoFromConsumer (
243
+ genericOp.getMatchingIndexingMap (opOperand), packOp. getMixedTiles (),
244
+ packOp. getInnerDimsPos (), packOp. getOuterDimsPerm (),
245
+ packOp. getPaddingValue () );
194
246
195
247
Location loc = packOp.getLoc ();
196
248
SmallVector<Value> inputOperands;
197
249
SmallVector<AffineMap> indexingMaps;
198
250
for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
199
251
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
200
- rewriter, loc, packOp , genericOp, inputOperand);
252
+ rewriter, loc, packInfo , genericOp, inputOperand);
201
253
inputOperands.push_back (packedOperand);
202
254
indexingMaps.push_back (packedIndexingMap);
203
255
}
204
256
205
257
int64_t numLoops = genericOp.getNumLoops ();
206
- int64_t numInnerLoops = packOp. getInnerDimsPos (). size ();
258
+ int64_t numInnerLoops = packInfo. getNumTiledLoops ();
207
259
int64_t newNumLoops = numLoops + numInnerLoops;
208
260
SmallVector<utils::IteratorType> iterTypes =
209
261
genericOp.getIteratorTypesArray ();
210
262
iterTypes.append (numInnerLoops, utils::IteratorType::parallel);
211
263
264
+ // Rebuild the indexing map for the corresponding init operand.
265
+ auto [packedOutOperand, packedOutIndexingMap] =
266
+ getOrCreatePackedViewOfOperand (rewriter, loc, packInfo, genericOp,
267
+ opOperand);
212
268
SmallVector<AffineExpr> outExprs (
213
- genericOp.getMatchingIndexingMap (opOperand).getResults ());
269
+ packedOutIndexingMap.getResults ().drop_back (numInnerLoops));
270
+ // Apply transpose to the indexing map, because we'll replace the init operand
271
+ // with the destination of pack op.
272
+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
273
+ if (!outerDimsPerm.empty ()) {
274
+ applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
275
+ }
214
276
for (int i = 0 ; i < numInnerLoops; ++i)
215
277
outExprs.push_back (rewriter.getAffineDimExpr (numLoops + i));
216
- indexingMaps.push_back (
217
- AffineMap::get (newNumLoops, 0 , outExprs, rewriter.getContext ()));
278
+ AffineMap outMap =
279
+ AffineMap::get (newNumLoops, 0 , outExprs, rewriter.getContext ());
280
+ indexingMaps.push_back (outMap);
218
281
219
282
auto newGenericOp = rewriter.create <linalg::GenericOp>(
220
283
loc, packOp.getDestType (), inputOperands, packOp.getDest (), indexingMaps,
0 commit comments