Skip to content

Commit 8c427a4

Browse files
author
git apple-llvm automerger
committed
Merge commit '518e6f341ddd' from llvm.org/main into apple/main
2 parents a2b4805 + 518e6f3 commit 8c427a4

File tree

3 files changed

+170
-79
lines changed

3 files changed

+170
-79
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
896896
/*desc=*/[{
897897
Return the indexing maps within the current operation.
898898
}],
899-
/*retTy=*/"SmallVector<AffineMap, 4>",
899+
/*retTy=*/"SmallVector<AffineMap>",
900900
/*methodName=*/"getIndexingMaps",
901901
/*args=*/(ins),
902902
/*methodBody=*/"",
@@ -931,6 +931,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
931931
return getIndexingMaps()[i];
932932
}]
933933
>,
934+
InterfaceMethod<
935+
/*desc=*/[{
936+
Return the input indexing maps.
937+
}],
938+
/*retTy=*/"SmallVector<AffineMap>",
939+
/*methodName=*/"getInputIndexingMaps",
940+
/*args=*/(ins),
941+
/*methodBody=*/"",
942+
/*defaultImplementation=*/[{
943+
auto maps = $_op.getIndexingMaps();
944+
return SmallVector<AffineMap>{maps.begin(),
945+
maps.begin() + $_op.getNumInputs()};
946+
}]
947+
>,
934948
InterfaceMethod<
935949
/*desc=*/[{
936950
Return the output indexing map at index `i`.
@@ -944,6 +958,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
944958
return getIndexingMaps()[i + $_op.getNumInputs()];
945959
}]
946960
>,
961+
InterfaceMethod<
962+
/*desc=*/[{
963+
Return the output indexing maps.
964+
}],
965+
/*retTy=*/"SmallVector<AffineMap>",
966+
/*methodName=*/"getOutputIndexingMaps",
967+
/*args=*/(ins),
968+
/*methodBody=*/"",
969+
/*defaultImplementation=*/[{
970+
auto maps = $_op.getIndexingMaps();
971+
return SmallVector<AffineMap>{maps.begin() + $_op.getNumInputs(),
972+
maps.begin() + $_op.getNumShapedOperands()};
973+
}]
974+
>,
947975
InterfaceMethod<
948976
/*desc=*/[{
949977
Return whether the op has only MemRef input and outputs.

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 105 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -61,36 +61,34 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
6161
/// Return the indexing map to use in the fused operation for the producer
6262
/// operand `producerOpOperand`, given the producer's result indexing map and
6363
/// the indexing map of the result of the producer in the consumer.
64-
static void getIndexingMapOfProducerOperandsInFusedOp(
65-
LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
66-
SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
64+
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
65+
OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
66+
AffineMap fusedConsumerArgIndexMap) {
6767
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
6868
// from consumer loop -> consumer arg tensor index/producer result tensor
6969
// index. The fused loop is same as the consumer loop. For each producer arg
7070
// the indexing map to be computed is a map from consumer loop -> producer
7171
// arg tensor index.
72-
73-
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
7472
// producerResultIndexMap is a map from producer loop -> tensor index.
7573
// Compute the inverse to get map from tensor index -> producer loop.
7674
// The inverse is a map from producer result tensor index -> producer loop.
7775
AffineMap invProducerResultIndexMap =
7876
inversePermutation(producerResultIndexMap);
7977
assert(invProducerResultIndexMap &&
8078
"expected producer result indexig map to be invertible");
81-
for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
82-
// argMap is a map from producer loop -> producer arg tensor index.
83-
AffineMap argMap = producer.getInputIndexingMap(argNum);
84-
85-
// Compose argMap with invProducerResultIndexMap to get a map from
86-
// producer result tensor index -> producer arg tensor index.
87-
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
88-
89-
// Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
90-
// consumer loop/ fused loop -> producer arg tensor index.
91-
AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
92-
fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
93-
}
79+
80+
LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
81+
// argMap is a map from producer loop -> producer arg tensor index.
82+
AffineMap argMap =
83+
producer.getIndexingMap(producerOpOperand.getOperandNumber());
84+
85+
// Compose argMap with invProducerResultIndexMap to get a map from
86+
// producer result tensor index -> producer arg tensor index.
87+
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
88+
89+
// Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
90+
// consumer loop/ fused loop -> producer arg tensor index.
91+
return t1.compose(fusedConsumerArgIndexMap);
9492
}
9593

9694
/// Generate the region of the fused tensor operation. The region of the fused
@@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
163161
.drop_front(numProducerIndices)
164162
.take_front(producer.getNumInputs()))
165163
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
164+
165+
// 4.b. Producer output operand/map that is fused needs to be mapped to the
166+
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
167+
assert(producer->getNumResults() == 1 && "expected single result producer");
168+
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
169+
BlockArgument bbArg =
170+
producerBlock.getArguments()
171+
.drop_front(numConsumerIndices + producer.getNumInputs())
172+
// TODO: bbArg index of the fused producer output; only output 0 is handled.
173+
.front();
174+
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
175+
}
166176
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
167177
for (BlockArgument bbArg : consumerBlock.getArguments()
168178
.drop_front(numConsumerIndices)
@@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
221231
!controlFn(producer->getResult(0), consumerOpOperand))
222232
return llvm::None;
223233

224-
unsigned numFusedOperands =
225-
producer.getNumInputs() + consumer.getNumInputs() - 1;
226-
227-
// Compute the fused operands list,
228-
SmallVector<Value, 2> fusedOperands;
229-
fusedOperands.reserve(numFusedOperands);
230-
auto consumerOperands = consumer.getInputs();
231-
auto producerOperands = producer.getInputs();
232-
fusedOperands.assign(consumerOperands.begin(),
233-
std::next(consumerOperands.begin(), consumerIdx));
234-
fusedOperands.append(producerOperands.begin(), producerOperands.end());
235-
fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
236-
consumerOperands.end());
237-
238-
// Compute indexing_maps for the fused operation. The indexing_maps for the
239-
// operands of the consumers that aren't fused are the same. The
240-
// indexing_maps for the producers need to be computed based on the
241-
// indexing_map of the operand at consumerIdx in the consumer.
242-
SmallVector<Attribute, 4> fusedIndexMaps;
243-
auto consumerIndexMaps = consumer.indexing_maps();
244-
fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
245-
fusedIndexMaps.assign(consumerIndexMaps.begin(),
246-
std::next(consumerIndexMaps.begin(), consumerIdx));
247-
// Compute indexing maps for the producer args in the fused operation.
248-
getIndexingMapOfProducerOperandsInFusedOp(
249-
producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
250-
251-
// Append the indexing maps for the remaining consumer operands.
252-
fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
253-
consumerIndexMaps.end());
234+
// TODO: allow fusing the producer of an output operand.
235+
assert(consumerIdx < consumer.getNumInputs() &&
236+
"expected producer of input operand");
237+
238+
// Compute the fused operands list and indexing maps.
239+
SmallVector<Value> fusedOperands;
240+
SmallVector<AffineMap> fusedIndexMaps;
241+
fusedOperands.reserve(producer->getNumOperands() +
242+
consumer->getNumOperands());
243+
fusedIndexMaps.reserve(producer->getNumOperands() +
244+
consumer->getNumOperands());
245+
// In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
246+
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
247+
llvm::append_range(fusedOperands,
248+
consumer.getInputs().take_front(consumerIdx));
249+
llvm::append_range(
250+
fusedIndexMaps,
251+
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
252+
consumerIdx));
253+
// 4. Splice in producer's input operands/maps.
254+
llvm::append_range(fusedOperands, producer.getInputs());
255+
assert(producer->getNumResults() == 1 && "expected single result producer");
256+
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
257+
for (auto &inputOpOperand : producer.getInputOpOperands()) {
258+
// Compute indexing maps for the producer args in the fused operation.
259+
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
260+
inputOpOperand, producerResultIndexMap,
261+
consumer.getInputIndexingMap(consumerIdx));
262+
fusedIndexMaps.push_back(map);
263+
}
264+
// 4.b. Producer output operand/map that is fused needs to be passed if it is
265+
// an "initTensor" (i.e. its value is actually read).
266+
assert(producer->getNumResults() == 1 && "expected single result producer");
267+
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
268+
llvm::append_range(fusedOperands, producer.getOutputs().take_front());
269+
// Compute indexing maps for the producer args in the fused operation.
270+
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
271+
producer.getOutputOpOperands().front(), producerResultIndexMap,
272+
consumer.getOutputIndexingMap(0));
273+
fusedIndexMaps.push_back(map);
274+
}
275+
// 5. Remaining consumer's input operands/maps (drop past index
276+
// `consumerIdx`).
277+
llvm::append_range(fusedOperands,
278+
consumer.getInputs().drop_front(consumerIdx + 1));
279+
llvm::append_range(
280+
fusedIndexMaps,
281+
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
282+
consumerIdx + 1));
283+
// 6. All of consumer's output operands (skip operands: added by the builder).
284+
// llvm::append_range(fusedOperands, consumer.getOutputs());
285+
llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
286+
// 7. All of producer's output operands/maps except the one fused.
287+
// TODO: allow fusion of multi-result producers.
288+
assert(producer->getNumResults() == 1 && "expected single result producer");
254289

255290
// Generate the fused op.
256-
LinalgOp fusedOp;
291+
Operation *fusedOp;
257292
if (isa<GenericOp>(producer.getOperation()) &&
258293
isa<GenericOp>(consumer.getOperation())) {
259-
fusedOp =
260-
rewriter
261-
.create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
262-
/*inputs=*/fusedOperands,
263-
// TODO: handle outputs.
264-
consumer.getOutputs(),
265-
rewriter.getArrayAttr(fusedIndexMaps),
266-
consumer.iterator_types(),
267-
/*doc=*/nullptr,
268-
/*library_call=*/nullptr,
269-
/*sparse=*/nullptr)
270-
.getOperation();
294+
fusedOp = rewriter.create<GenericOp>(
295+
consumer.getLoc(), consumer->getResultTypes(),
296+
/*inputs=*/fusedOperands,
297+
// TODO: handle outputs.
298+
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
299+
consumer.iterator_types(),
300+
/*doc=*/nullptr,
301+
/*library_call=*/nullptr,
302+
/*sparse=*/nullptr);
271303
} else {
272-
fusedOp =
273-
rewriter
274-
.create<IndexedGenericOp>(
275-
consumer.getLoc(), consumer->getResultTypes(),
276-
/*inputs=*/fusedOperands,
277-
// TODO: handle outputs.
278-
consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
279-
consumer.iterator_types(),
280-
/*doc=*/nullptr,
281-
/*library_call=*/nullptr,
282-
/*sparse=*/nullptr)
283-
.getOperation();
304+
fusedOp = rewriter.create<IndexedGenericOp>(
305+
consumer.getLoc(), consumer->getResultTypes(),
306+
/*inputs=*/fusedOperands,
307+
// TODO: handle outputs.
308+
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
309+
consumer.iterator_types(),
310+
/*doc=*/nullptr,
311+
/*library_call=*/nullptr,
312+
/*sparse=*/nullptr);
284313
}
285314

286315
// Construct an AffineMap from consumer loops to producer loops.
287316
// consumer loop -> tensor index
288317
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
289-
// producer loop -> tensor index
290-
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
291318
// tensor index -> producer loop
292319
AffineMap invProducerResultIndexMap =
293320
inversePermutation(producerResultIndexMap);
@@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
297324
AffineMap consumerToProducerLoopsMap =
298325
invProducerResultIndexMap.compose(consumerResultIndexMap);
299326

300-
generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
301-
consumer, consumerToProducerLoopsMap,
302-
consumerIdx, consumer.getNumLoops());
327+
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
328+
consumerToProducerLoopsMap, consumerIdx,
329+
consumer.getNumLoops());
303330
return SmallVector<Value, 1>(fusedOp->getResults());
304331
}
305332

mlir/test/Dialect/Linalg/fusion-tensor.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,3 +616,39 @@ func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
616616
} -> tensor<?x1xf32>
617617
return %2 : tensor<?x1xf32>
618618
}
619+
620+
// -----
621+
622+
func private @compute1(%a: f64) -> f64
623+
func private @compute2(%a: f64, %b: i32) -> i32
624+
625+
// CHECK-LABEL: func @generic_index_op2(
626+
func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> {
627+
%0 = linalg.generic {
628+
indexing_maps = [affine_map<(i, j) -> (i, j)>],
629+
iterator_types = ["parallel", "parallel"]}
630+
outs(%arg0 : tensor<1x8xf64>) {
631+
^bb0(%a: f64):
632+
%r = call @compute1(%a) : (f64) -> f64
633+
linalg.yield %r : f64
634+
} -> tensor<1x8xf64>
635+
636+
// CHECK-NEXT: %[[R:.*]] = linalg.generic
637+
// CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
638+
// CHECK-NEXT: %[[A:.*]] = call @compute1(%[[BBA]]) : (f64) -> f64
639+
// CHECK-NEXT: %[[B:.*]] = call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
640+
// CHECK-NEXT: linalg.yield %[[B]] : i32
641+
// CHECK-NEXT: } -> tensor<1x8xi32>
642+
%1 = linalg.generic {
643+
indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
644+
iterator_types = ["parallel", "parallel"]}
645+
ins(%0 : tensor<1x8xf64>)
646+
outs(%arg1 : tensor<1x8xi32>) {
647+
^bb0(%a: f64, %b: i32):
648+
%r = call @compute2(%a, %b) : (f64, i32) -> i32
649+
linalg.yield %r : i32
650+
} -> tensor<1x8xi32>
651+
652+
// CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
653+
return %1 : tensor<1x8xi32>
654+
}

0 commit comments

Comments
 (0)