@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
-      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+  FailureOr<Value> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+      ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);

@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
     // loops. This could be controlled by the user for more flexibility.

     SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+                        combinerOps) ||
         combinerOps.size() != 1)
       return op->emitOpError("Failed to analyze the reduction operation.");

@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
           "Failed to get an identity value for the reduction operation.");

     ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+        linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));

     // Calculate the new shape: we insert the new dimensions based on the index
     // of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
         dynamicDims.push_back(b.create<tensor::DimOp>(
-            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+            loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
     }
     Value emptyTensor = b.create<tensor::EmptyOp>(
-        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
-        dynamicDims);
+        loc, newOutputShape,
+        linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
     Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
     auto identityTensor =
         b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
-    return identityTensor.getOperation();
+    return identityTensor.getResult(0);
   }

   Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
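
For intuition, here is a rough MLIR sketch of what this entry point now builds for one result. The shapes, tile size, and SSA names are hypothetical, not taken from the patch: assume result 0 is a sum over d1 of a tensor<?x64xf32> and the reduction dimension is tiled by 8. The method fills an empty tensor, whose shape is the result shape with the tiled reduction dimension inserted, with the reduction's identity, and now returns that Value (the fill result) rather than the fill operation itself:

  // Hypothetical IR; %d0 holds the dynamic extent of d0.
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty(%d0) : tensor<?x8xf32>
  %init = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x8xf32>)
      -> tensor<?x8xf32>
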
@@ -312,44 +313,64 @@ struct LinalgOpPartialReductionInterface
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);

-    AffineMap oldOutputMap =
-        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
-                                       reductionDims.size());
-
-    for (int idx : reductionDims)
-      outputExpr[idx] = b.getAffineDimExpr(idx);
-    int currExpr = 0;
-    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
-      if (outputExpr[idx])
-        continue;
-      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    // Step 1. Extend the init maps to include the reduction dimensions, since
+    // we are converting them to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
     }

-    // Step 1: Extract a slice of the input operands.
-    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
+      int64_t initRank = valueMap.getNumResults();
+      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
+      SmallVector<OpFoldResult> initSizes;
+      for (AffineExpr dimExpr : valueMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        initSizes.push_back(sizes[dim.getPosition()]);
+      }
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, initSizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }

-    // Step 2: Extract the accumulator operands
-    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
-                                                 sizes, strides);
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }

-    // Step3. Create a generic op where the reduction dimensions are replaced
-    // by a parallel dimension of the size of reduction.
+    // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
     for (int dim : reductionDims)
       newIteratorTypes[dim] = utils::IteratorType::parallel;
-    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
-                                    linalgOp.getContext());
+
+    // Step 4. Create the new generic op.
     auto genericOp =
-        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
-                            ValueRange({out}), newMaps, newIteratorTypes);
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
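
Under the same hypothetical setup as above (one sum result over d1, tile size 8; all names assumed, not from the patch), the generic op built here would look roughly like this: the tiled reduction iterator becomes parallel, and the init is indexed by its original map with the reduction dimension appended, so each position along d1 accumulates its own partial sum:

  // Hypothetical IR; %in_slice and %init_slice are the tiled operands.
  %partial = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%in_slice : tensor<?x8xf32>) outs(%init_slice : tensor<?x8xf32>) {
    ^bb0(%a: f32, %acc: f32):
      %sum = arith.addf %a, %acc : f32
      linalg.yield %sum : f32
  } -> tensor<?x8xf32>
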
@@ -361,40 +382,53 @@ struct LinalgOpPartialReductionInterface
                             ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);

-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
-    // Then create a new reduction that only reduce the newly added dimensions
-    // from the previous op.
-    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
-    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
-    SmallVector<AffineExpr> exprs;
-
-    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (reductionDimsSet.contains(i)) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
-      } else {
-        exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    //   parallel                         --> parallel
+    //   reduction + not in reductionDims --> parallel (already reduced)
+    //   reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+                                               utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {
+        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
+                                         inputMap.getNumResults());
       }
     }

-    AffineMap outputMap =
-        AffineMap::get(intermRank, 0, exprs, op->getContext());
-    SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
-
-    SmallVector<Operation *, 4> combinerOps;
-    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
-    Operation *reductionOp = combinerOps[0];
-
     auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
-        linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
-        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          Operation *clonedReductionOp = b.clone(*reductionOp);
-          clonedReductionOp->setOperand(0, inputs[0]);
-          clonedReductionOp->setOperand(1, inputs[1]);
-          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
+        indexingMaps, iterators,
+        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          int64_t numInits = linalgOp.getNumDpsInits();
+          SmallVector<Value> yieldedValues;
+          for (int idx : llvm::seq<int>(0, numInits)) {
+            // Get the combiner op.
+            SmallVector<Operation *, 4> combinerOps;
+            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
+            // Combine the input at idx and output at numInits + idx.
+            clonedReductionOp->setOperand(0, inputs[idx]);
+            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
+            // Yield.
+            yieldedValues.push_back(clonedReductionOp->getResult(0));
+          }
+          b.create<linalg::YieldOp>(loc, yieldedValues);
         });
     return reduction.getOperation();
   }
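
To round off the hypothetical example (same assumed shapes and names as the earlier sketches), the merge op emitted here reduces the appended dimension back out, pairing partial input idx with init numInits + idx in the cloned combiner and yielding one value per result; with a single sum result it would look roughly like:

  // Hypothetical IR; d1 is in reductionDims, so it stays a reduction here.
  %merged = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%partial : tensor<?x8xf32>) outs(%out : tensor<?xf32>) {
    ^bb0(%p: f32, %acc: f32):
      %sum = arith.addf %p, %acc : f32
      linalg.yield %sum : f32
  } -> tensor<?xf32>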