
Commit 8e890c8

save

1 parent 1e7d047 commit 8e890c8

10 files changed: +211 -121 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 2 additions & 2 deletions
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                  TransformHandleTypeInterface:$split_linalg_op,
                  TransformHandleTypeInterface:$combining_linalg_op,
                  TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                  TransformHandleTypeInterface:$split_linalg_op,
                  TransformHandleTypeInterface:$combining_linalg_op,
                  TransformHandleTypeInterface:$forall_op);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// The op initializing the tensor used for partial reductions.
-  Operation *initialOp;
+  /// Initial values used for partial reductions.
+  SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.
   scf::ForallOp loops;
 };
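Note: downstream code that previously stored `result.initialOp` now receives one `Value` per partial reduction. A minimal sketch of a consumer under the new field, assuming the usual MLIR includes (the helper name is illustrative, not part of this commit):

// Hypothetical consumer of the updated result struct.
void inspectInitValues(const linalg::ForallReductionTilingResult &res) {
  for (Value init : res.initialValues) {
    // Each init is typically the result of a linalg.fill created by the
    // PartialReductionOpInterface; getDefiningOp() recovers the op when
    // an Operation* handle is still needed.
    if (Operation *def = init.getDefiningOp())
      def->dump();
  }
}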

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 2 deletions
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// Initial op
-  Operation *initialOp;
+  /// Initial values used for reduction.
+  SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
 };

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 3 additions & 2 deletions
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         operation reduction. The tensor shape is equal to operation result
         shape with new dimension for each non zero tile size.
       }],
-      /*retType=*/"FailureOr<Operation*>",
+      /*retType=*/"FailureOr<Value>",
       /*methodName=*/"generateInitialTensorForPartialReduction",
       /*args=*/(ins
           "OpBuilder &":$b,
-          "Location ":$loc,
+          "Location":$loc,
+          "int64_t":$resultNumber,
          "ArrayRef<OpFoldResult>":$sizes,
          "ArrayRef<int>":$reductionDim),
      /*methodBody=*/"",
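With the new `resultNumber` argument the method is queried once per op result rather than once per op. A minimal sketch of a call site under the updated signature, assuming `op` implements `PartialReductionOpInterface` and `b`, `loc`, `sizes`, and `reductionDims` come from the tiling driver (illustrative; the real driver change is in Tiling.cpp below):

SmallVector<Value> initTensors;
for (int64_t resultNumber : llvm::seq<int64_t>(0, op->getNumResults())) {
  // One init tensor per result; each result may use a different
  // identity value for its combiner.
  FailureOr<Value> init = op.generateInitialTensorForPartialReduction(
      b, loc, resultNumber, sizes, reductionDims);
  if (failed(init))
    return failure();
  initTensors.push_back(*init);
}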

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 2 deletions
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(

   if (failed(result))
     return emitDefaultSilenceableFailure(target);
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops);
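The `getDefiningOp()` call assumes every initial value is produced by an operation, which holds here because the Linalg implementation below returns a `linalg.fill` result. A sketch that makes the assumption explicit (hypothetical helper, not in this commit):

static Operation *initValueToHandleOp(Value initValue) {
  // Transform handles refer to operations, so the init value must be
  // op-produced (a linalg.fill result here), never a block argument.
  Operation *def = initValue.getDefiningOp();
  assert(def && "expected init value to be produced by an operation");
  return def;
}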

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 6 additions & 7 deletions
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
     PartialReductionOpInterface partialReductionIface =
         llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    FailureOr<Operation *> reductionNeutralTensorOp =
+    assert(op->getNumResults() == 1 && "Multiple results not supported.");
+    FailureOr<Value> reductionNeutralTensor =
         partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensorOp));
-    builder.create<scf::YieldOp>(
-        reductionNeutralTensorOp.value()->getResult(0));
+            builder, builder.getLoc(), 0, shape, {});
+    assert(succeeded(reductionNeutralTensor));
+    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
   }
   return ifOp.getResult(0);
 }
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
     ImplicitLocOpBuilder &builder) {
   // TODO: add support for multiple destination passing style initial value
   // operands.
-  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
-  // needs to also support multiple DPS initial operands.
+  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
   Value spmdizedInitOperand =

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 15 additions & 10 deletions
@@ -692,12 +692,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
         op, "reduction dimension must be mapped to threads");

   // 1. Create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
+  SmallVector<Value> initTensors;
+  initTensors.reserve(op->getNumResults());
+  for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+    FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+        b, loc, idx, numThreads, reductionDim);
+    if (failed(initValue))
+      return b.notifyMatchFailure(
+          op, "cannot create a tensor of identity value for result " +
+                  std::to_string(idx));
+    initTensors.push_back(initValue.value());
+  }

   // Gather destination tensors.
   SmallVector<Value> dest;
@@ -715,8 +720,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

   // 2. Create the ForallOp with an empty region.
   scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
-      (*identityTensor)->getResults(), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+      mapping);

   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `forallOp`.
@@ -726,7 +731,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                                tiledSizes);

-  // 4. Clone the tileable op and update its destination operands to use the
+  // 4b. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +843,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

   // 8. Return.
   ForallReductionTilingResult results;
-  results.initialOp = *identityTensor;
+  results.initialValues = initTensors;
   results.loops = forallOp;
   results.parallelTiledOp = tiledOp;
   results.mergeOp = mergeOp;

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 102 additions & 68 deletions
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
-      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+  FailureOr<Value> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+      ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);

@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
     // loops. This could be controlled by user for more flexibility.

     SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+                        combinerOps) ||
         combinerOps.size() != 1)
       return op->emitOpError("Failed to anaysis the reduction operation.");

@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
           "Failed to get an identity value for the reduction operation.");

     ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+        linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));

     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
         dynamicDims.push_back(b.create<tensor::DimOp>(
-            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+            loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
     }
     Value emptyTensor = b.create<tensor::EmptyOp>(
-        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
-        dynamicDims);
+        loc, newOutputShape,
+        linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
     Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
     auto identityTensor =
         b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
-    return identityTensor.getOperation();
+    return identityTensor.getResult(0);
   }

   Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,64 @@ struct LinalgOpPartialReductionInterface
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);

-    AffineMap oldOutputMap =
-        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
-                                       reductionDims.size());
-
-    for (int idx : reductionDims)
-      outputExpr[idx] = b.getAffineDimExpr(idx);
-    int currExpr = 0;
-    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
-      if (outputExpr[idx])
-        continue;
-      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    // Step 1. Extend init maps to have reduction dimension dims, since we
+    // are converting them to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
     }

-    // Step 1: Extract a slice of the input operands.
-    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
+      int64_t initRank = valueMap.getNumResults();
+      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
+      SmallVector<OpFoldResult> initSizes;
+      for (AffineExpr dimExpr : valueMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        initSizes.push_back(sizes[dim.getPosition()]);
+      }
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, initSizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }

-    // Step 2: Extract the accumulator operands
-    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
-                                                 sizes, strides);
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }

-    // Step3. Create a generic op where the reduction dimensions are replaced
-    // by a parallel dimension of the size of reduction.
+    // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
     for (int dim : reductionDims)
       newIteratorTypes[dim] = utils::IteratorType::parallel;
-    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
-                                    linalgOp.getContext());
+
+    // Step 4. Create the new generic op.
     auto genericOp =
-        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
-                            ValueRange({out}), newMaps, newIteratorTypes);
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -361,40 +382,53 @@ struct LinalgOpPartialReductionInterface
                     ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);

-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
-    // Then create a new reduction that only reduce the newly added dimensions
-    // from the previous op.
-    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
-    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
-    SmallVector<AffineExpr> exprs;
-
-    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (reductionDimsSet.contains(i)) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
-      } else {
-        exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    //   parallel                         --> parallel
+    //   reduction + not in reductionDims --> parallel (already reduced)
+    //   reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
                                                utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {
+        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
+                                         inputMap.getNumResults());
       }
     }

-    AffineMap outputMap =
-        AffineMap::get(intermRank, 0, exprs, op->getContext());
-    SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
-
-    SmallVector<Operation *, 4> combinerOps;
-    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
-    Operation *reductionOp = combinerOps[0];
-
     auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
-        linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
-        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          Operation *clonedReductionOp = b.clone(*reductionOp);
-          clonedReductionOp->setOperand(0, inputs[0]);
-          clonedReductionOp->setOperand(1, inputs[1]);
-          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
+        indexingMaps, iterators,
+        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          int64_t numInits = linalgOp.getNumDpsInits();
+          SmallVector<Value> yieldedValues;
+          for (int idx : llvm::seq<int>(0, numInits)) {
+            // Get the combiner op.
+            SmallVector<Operation *, 4> combinerOps;
+            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
+            // Combine the input at idx and output at numInits + idx.
+            clonedReductionOp->setOperand(0, inputs[idx]);
+            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
+            // Yield.
+            yieldedValues.push_back(clonedReductionOp->getResult(0));
          }
+          b.create<linalg::YieldOp>(loc, yieldedValues);
        });
     return reduction.getOperation();
   }
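The heart of the multi-result generalization is extending each init's indexing map with the tiled reduction dimensions, as done in both `tileToPartialReduction` and `mergeReductions` above. For an init map `(d0, d1) -> (d0)` with `reductionDims = {1}`, the extended map is `(d0, d1) -> (d0, d1)`: the reduced dimension is materialized as a parallel dimension of the partial-result tensor. A standalone sketch of that step (`extendInitMap` is an illustrative name, not part of this commit):

static AffineMap extendInitMap(OpBuilder &b, AffineMap map,
                               ArrayRef<int> reductionDims) {
  // Append one dim expression per tiled reduction dim, e.g.
  // (d0, d1) -> (d0) with reductionDims = {1} becomes (d0, d1) -> (d0, d1).
  for (int redPos : reductionDims)
    map = map.insertResult(b.getAffineDimExpr(redPos), map.getNumResults());
  return map;
}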
