Skip to content

Commit 7594f50

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).
Replace the uses of deprecated Structured Op Interface methods in Fusion.cpp. This patch is based on https://reviews.llvm.org/D103394. Differential Revision: https://reviews.llvm.org/D103437
1 parent c2e5226 commit 7594f50

File tree

1 file changed

+43
-46
lines changed

1 file changed

+43
-46
lines changed

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -69,38 +69,36 @@ struct ShapeDimension {
6969
static ShapeDimension
7070
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
7171
bool fromSubViewOpOnly = false) {
72-
auto maps = op.indexing_maps();
7372
// Iterate over the inputs and outputs in order.
7473
// Extract the subranges from the linearized ranges.
75-
for (auto en : llvm::enumerate(op.getShapedOperands())) {
74+
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
7675
// The method `getRangeFromOperandShape` requires using SubViewOp or
7776
// SubTensorOps. If the value isn't defined from there, continue.
7877
// TODO: The method should be adapted to get the values from
7978
// `ViewInterface`. The interface needs a `getOrCreateRanges` method which
8079
// currently returns a `linalg.range`. The fix here is to move this op to
8180
// `std` dialect and add the method to `ViewInterface`.
8281
if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
83-
en.value().getDefiningOp()))
82+
opOperand->get().getDefiningOp()))
8483
continue;
8584

86-
unsigned idx = en.index();
87-
auto map = maps[idx].cast<AffineMapAttr>().getValue();
88-
LLVM_DEBUG(llvm::dbgs()
89-
<< "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
85+
AffineMap map = op.getTiedIndexingMap(opOperand);
86+
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
87+
<< opOperand->getOperandNumber() << "\n");
9088
LLVM_DEBUG(llvm::dbgs()
9189
<< "getShapeDefiningLoopRange map: " << map << "\n");
92-
Value shape = en.value();
9390
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
94-
for (auto en2 : llvm::enumerate(map.getResults())) {
95-
auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
91+
for (auto en : llvm::enumerate(map.getResults())) {
92+
auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
9693
if (!dimExpr)
9794
continue;
98-
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
95+
if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
9996
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
10097
<< loopDepth << "\n");
101-
LLVM_DEBUG(llvm::dbgs()
102-
<< "getShapeDefiningLoopRange shape: " << shape << "\n");
103-
return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
98+
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
99+
<< opOperand->get() << "\n");
100+
return ShapeDimension{opOperand->get(),
101+
static_cast<unsigned>(en.index())};
104102
}
105103
}
106104
}
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
122120
// would need to add the intermediate results to `linalg.yield`. After that a
123121
// canonicalization pass would move the unused output args of the `tiled_loop`
124122
// to the `input` section.
125-
static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
123+
static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
126124
auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
127125
if (!tiledLoop)
128-
return llvm::to_vector<4>(producer.getShapedOperands());
126+
return producer.getInputAndOutputOperands();
129127

130-
SmallVector<Value, 4> tiledOperands;
128+
SmallVector<Value> tiledOperands;
131129
assert(producer.hasTensorSemantics() &&
132130
"only fusion on tensors is currently supported for TiledLinalgOp");
133131

134-
for (auto producerInput : producer.getInputTensors()) {
135-
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
132+
for (OpOperand *producerInput : producer.getInputTensorOperands()) {
133+
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
136134
if (addedInput == nullptr)
137-
addedInput = &tiledLoop.appendInputOperand(b, producerInput);
135+
addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
138136
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
139137
tiledOperands.push_back(addedBlockArg);
140138
}
141-
for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
142-
Value producerOutput = en.value();
143-
144-
Value result = producer->getResult(en.index());
139+
for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
140+
OpResult result = producer.getTiedOpResult(producerOutput);
145141
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
146142
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
147143
assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
152148
int opNumber = isInput ? resultInputOperand->getOperandNumber()
153149
: resultOutputOperand->getOperandNumber();
154150

155-
OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
151+
OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
156152
if (addedOutput == nullptr)
157-
addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
158-
: &tiledLoop.appendOutputOperand(b, producerOutput);
153+
addedOutput =
154+
isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
155+
: &tiledLoop.appendOutputOperand(b, producerOutput->get());
159156

160157
OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
161158
auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
200197
}
201198

202199
SmallVector<Value, 8> clonedShapes;
203-
clonedShapes.reserve(producer.getNumShapedOperands());
200+
clonedShapes.reserve(producer.getNumInputsAndOutputs());
204201

205202
// Compute subranges for all tensor input/output operands.
206203
clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
267264
llvm_unreachable("SubviewOp or SubTensorOp expected");
268265
}
269266

270-
/// Fuses the producer of `producerIdx` into the loop immediately enclosing
271-
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
272-
/// is needed just before the `consumer.
273-
///
274-
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
275-
/// 2 cases:
276-
/// 1. Buffer case: `producerIdx` is the index of the buffer in
277-
/// `producer.getOutputBuffers()`.
278-
/// 2. Tensor case: `producerIdx` is the index of the tensor in
279-
/// `producer.getResults()`.
267+
/// Fuses the producer into the loop immediately enclosing the consumer.
268+
/// This is achieved by "recomputing" the producer at the time it
269+
/// is needed just before the consumer.
280270
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
281271
OpOperand &consumerOpOperand) {
282272
LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
548538
OpBuilder::InsertionGuard g(b);
549539
b.setInsertionPoint(consumerOp);
550540
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
541+
OpOperand *opOperand =
542+
producerOp.getOutputOperand(producerOpResult.getResultNumber());
551543
LinalgOp fusedProducer =
552-
fuse(b, producerOp,
553-
producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
544+
fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
554545
consumerOpOperand);
555546

556547
// Replace use.
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
770761
FusableOpDependencesTy fusableDependences;
771762
DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
772763
for (LinalgOp op : reverse(ops)) {
773-
for (OpOperand &opOperand : op.getShapedOpOperands()) {
764+
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
774765
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
775-
fusableDependence = findFusableProducer(opOperand, dependenceGraph);
766+
fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
776767
if (!fusableDependence)
777768
continue;
778769
// Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
905896
// To keep the second type of information while letting the unfused op die
906897
// unused, we need to forward the producer output operand.
907898
if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
908-
for (auto &operand : forOp.getIterOpOperands())
909-
if (auto opResult = operand.get().dyn_cast<OpResult>())
910-
if (opResult.getOwner() == origOp)
911-
operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
899+
for (auto &operand : forOp.getIterOpOperands()) {
900+
if (auto opResult = operand.get().dyn_cast<OpResult>()) {
901+
if (opResult.getOwner() == origOp) {
902+
Value output =
903+
origOp.getOutputOperand(opResult.getResultNumber())->get();
904+
assert(output.getType().isa<RankedTensorType>());
905+
operand.set(output);
906+
}
907+
}
908+
}
912909
}
913910
}
914911
return fusedOps;

0 commit comments

Comments
 (0)