@@ -82,96 +82,6 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
  return std::make_tuple(res, loopIndexToRangeIndex);
}

- // IndexedGenericOp explicitly uses induction variables in the loop body. The
- // values of the indices used in the loop body for any given access of an
- // input/output memref (before the `subview` op was applied) should be
- // invariant with respect to tiling.
- //
- // Therefore, if the operation is tiled, we have to transform the indices
- // accordingly, i.e. offset them by the values of the corresponding induction
- // variables that are captured implicitly in the body of the op.
- //
- // Example. `linalg.indexed_generic` before tiling:
- //
- //   #id_2d = (i, j) -> (i, j)
- //   #pointwise_2d_trait = {
- //     indexing_maps = [#id_2d, #id_2d],
- //     iterator_types = ["parallel", "parallel"],
- //     n_views = [1, 1]
- //   }
- //   linalg.indexed_generic #pointwise_2d_trait %operand, %result {
- //   ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
- //     <some operations that use %i, %j>
- //   }: memref<50x100xf32>, memref<50x100xf32>
- //
- // After the tiling pass with tile sizes 10 and 25:
- //
- //   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
- //
- //   %c1 = constant 1 : index
- //   %c0 = constant 0 : index
- //   %c25 = constant 25 : index
- //   %c10 = constant 10 : index
- //   %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
- //   %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
- //   scf.for %k = %c0 to %operand_dim_0 step %c10 {
- //     scf.for %l = %c0 to %operand_dim_1 step %c25 {
- //       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
- //         : memref<50x100xf32> to memref<?x?xf32, #strided>
- //       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
- //         : memref<50x100xf32> to memref<?x?xf32, #strided>
- //       linalg.indexed_generic #pointwise_2d_trait %4, %5 {
- //       ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
- //         // Indices `k` and `l` are implicitly captured in the body.
- //         %transformed_i = addi %i, %k : index // index `i` is offset by %k
- //         %transformed_j = addi %j, %l : index // index `j` is offset by %l
- //         // Every use of %i, %j is replaced with %transformed_i, %transformed_j.
- //         <some operations that use %transformed_i, %transformed_j>
- //       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
- //     }
- //   }
- //
- // TODO: Investigate whether mixing implicit and explicit indices
- // does not lead to losing information.
- static void transformIndexedGenericOpIndices(
-     OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
-     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
-   auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
-   if (!indexedGenericOp)
-     return;
-
-   // `linalg.indexed_generic` comes in two flavours. One has a region with a
-   // single block that defines the loop body. The other has a `fun` attribute
-   // that refers to an existing function symbol. The `fun` function call will
-   // be inserted in the loop body in that case.
-   //
-   // TODO: Add support for `linalg.indexed_generic` with a `fun` attribute.
-   auto &region = indexedGenericOp.region();
-   if (region.empty()) {
-     indexedGenericOp.emitOpError("expected a region");
-     return;
-   }
-   auto &block = region.front();
-
-   OpBuilder::InsertionGuard g(b);
-   b.setInsertionPointToStart(&block);
-   for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
-     auto rangeIndex = loopIndexToRangeIndex.find(i);
-     if (rangeIndex == loopIndexToRangeIndex.end())
-       continue;
-     Value oldIndex = block.getArgument(i);
-     // Offset the index argument `i` by the value of the corresponding
-     // induction variable and replace all uses of the previous value.
-     Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
-                                       ivs[rangeIndex->second]);
-     for (auto &use : oldIndex.getUses()) {
-       if (use.getOwner() == newIndex.getDefiningOp())
-         continue;
-       use.set(newIndex);
-     }
-   }
- }
-
// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
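The comment above belongs to `transformIndexOps`, the IndexOp counterpart of the helper removed in this patch; its body lies outside the hunk. A minimal sketch of the same offset-by-IV rewrite, assuming `IndexOp::dim()` returns the loop dimension and that `Value::replaceAllUsesExcept` is available (this is not the verbatim upstream code):

    static void
    transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                      const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
      // Walk the linalg.index ops in the body block; offset each result that
      // corresponds to a tiled loop by the matching induction variable.
      Block &block = op->getRegion(0).front();
      for (IndexOp indexOp : block.getOps<IndexOp>()) {
        auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
        if (rangeIndex == loopIndexToRangeIndex.end())
          continue;
        OpBuilder::InsertionGuard g(b);
        b.setInsertionPointAfter(indexOp);
        Value offset = b.create<AddIOp>(indexOp.getLoc(), indexOp.getResult(),
                                        ivs[rangeIndex->second]);
        // Rewire every use of the old index except the AddIOp just created.
        indexOp.getResult().replaceAllUsesExcept(offset, offset.getDefiningOp());
      }
    }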
@@ -261,6 +171,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

+  // Tiling `linalg.indexed_generic` is not supported anymore; such ops are
+  // expected to be canonicalized to `linalg.generic` with `linalg.index` ops
+  // before tiling.
+  if (isa<IndexedGenericOp>(op))
+    return llvm::None;
+
  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops, only support tiling along the batch dimension (which is
    // the first loop).
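The conv-specific check itself continues past the hunk boundary. Purely as an illustration of what "tiling along the batch dimension only" could mean in code (reusing the `isZero` predicate referenced above; the guard shape is an assumption, not the upstream logic):

    // Hypothetical guard: reject tiling if any tile size past the first
    // (batch) loop dimension is non-zero.
    if (llvm::any_of(llvm::drop_begin(tileSizes, 1),
                     [](Value v) { return !isZero(v); }))
      return llvm::None;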
@@ -376,9 +290,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
      },
      options.distribution);

-  // 3a. Transform index arguments of `linalg.indexed_generic` w.r.t. the tiling.
-  transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
-  // 3b. Transform IndexOp results w.r.t. the tiling.
+  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
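Step 4's body falls below the hunk. A minimal sketch of the gathering it names, assuming each induction variable is a block argument of its loop body (`TiledLinalgOp` is the real result type; its exact fields are assumed here):

    // Recover each tile loop from its induction variable and return the
    // loops together with the tiled op.
    SmallVector<Operation *, 8> loops;
    loops.reserve(ivs.size());
    for (Value iv : ivs) {
      if (auto blockArg = iv.dyn_cast<BlockArgument>())
        loops.push_back(blockArg.getOwner()->getParentOp());
      else
        loops.push_back(nullptr); // IV not tied to a loop op.
    }
    return TiledLinalgOp{res, loops};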
@@ -521,7 +433,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
-  RewritePatternList<GenericOp, IndexedGenericOp,
+  RewritePatternList<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
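For context, `RewritePatternList` is a small variadic helper defined earlier in this file; dropping `IndexedGenericOp` from its type list stops the tiling pattern from being registered for that op. A hedged sketch of the helper's shape (the filter and benefit arguments of the real pattern are omitted and assumed to default):

    // Peel one op type per partial specialization, register a tiling pattern
    // for it, then recurse on the remaining op types.
    template <typename... OpTypes>
    class RewritePatternList;

    template <>
    class RewritePatternList<> {
    public:
      static void insert(RewritePatternSet &patterns,
                         const LinalgTilingOptions &options) {}
    };

    template <typename OpTy, typename... OpTypes>
    class RewritePatternList<OpTy, OpTypes...> {
    public:
      static void insert(RewritePatternSet &patterns,
                         const LinalgTilingOptions &options) {
        patterns.add<LinalgTilingPattern<OpTy>>(patterns.getContext(), options);
        RewritePatternList<OpTypes...>::insert(patterns, options);
      }
    };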