Skip to content

Commit f5f1a5c

Browse files
committed
[mlir][Linalg] Handle fusion on tensors for projected permutation.
In the past, the reshape op could be folded only if the indexing map was a permutation in the consumer's usage. We can relax the condition to allow projected permutations. This patch still disallows fusion for scalar cases. The scalar case is a corner case, because we would need to decide where to put the extra dims. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D92466
1 parent f5d5291 commit f5f1a5c

File tree

4 files changed

+123
-15
lines changed

4 files changed

+123
-15
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,12 @@ Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
118118
/// dimension is statically known, or -1 otherwise.
119119
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
120120

121-
/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
122-
/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
123-
/// Returns None if inverting the concatenated indexing map fails. Returns -1
121+
/// Returns the statically-known loop ranges of the `linalgOp`. Composes
122+
/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
123+
/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
124124
/// for non-statically-known loop ranges.
125125
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
126+
126127
/// Apply the permutation defined by `permutation` to `inVec`.
127128
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
128129
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -411,21 +411,19 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
411411
unsigned fusedTensorIndex) {
412412
// Is fusable only if:
413413
// - The linalgOp is a generic op, or an indexed_generic.
414-
// - All the indexing maps for operands in linalgOp are projected
414+
// - All the indexing maps for operands and results in linalgOp are projected
415415
// permutations.
416-
// - The indexing map at the position representing the fused tensor is a
417-
// permutation.
416+
// - The fused tensor is not a scalar.
418417
// - All the loops in linalgOp are parallel loops.
419418
return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
420419
linalgOp.hasTensorSemantics() &&
421-
llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
422-
linalgOp.getNumInputs()),
420+
llvm::all_of(linalgOp.indexing_maps().getValue(),
423421
[](Attribute attr) {
424422
return attr.cast<AffineMapAttr>()
425423
.getValue()
426424
.isProjectedPermutation();
427425
}) &&
428-
linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() &&
426+
linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
429427
llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
430428
return attr.cast<StringAttr>().getValue() ==
431429
getParallelIteratorTypeName();
@@ -446,18 +444,22 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
446444
reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
447445
RankedTensorType expandedType =
448446
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
449-
RankedTensorType foldedType =
450-
isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType();
451447
AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
452448

453449
// The reshape is folding/expanding consecutive dimensions. Given the indexing
454450
// map of the fused tensor find the number of dimensions each of the loops of
455451
// the original op is expanded into. Also record the shape of the expanded
456452
// dimensions.
457453
ArrayRef<int64_t> expandedShape = expandedType.getShape();
458-
SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
454+
Optional<SmallVector<int64_t, 4>> origOpLoopRange =
455+
getStaticLoopRanges(linalgOp);
456+
if (!origOpLoopRange) {
457+
linalgOp.emitError("unable to find loop range for operation");
458+
return llvm::None;
459+
}
460+
SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
459461
SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
460-
foldedType.getRank());
462+
fusedIndexMap.getNumDims());
461463
auto reassociationMaps = reshapeOp.getReassociationMaps();
462464
for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
463465
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -467,6 +469,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
467469
expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
468470
expandedDimsShape[pos].assign(shape.begin(), shape.end());
469471
}
472+
// The remaining dimensions remain the same.
473+
for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
474+
if (expandedDimsShape[i].empty())
475+
expandedDimsShape[i] = {(*origOpLoopRange)[i]};
470476

471477
if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
472478
// For indexed generic op, the region contains arguments that represent the
@@ -476,6 +482,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
476482
// front) are statically know. For dynamic case, we would need shape
477483
// information on these dimensions to get these.
478484
for (auto &expandedShape : expandedDimsShape) {
485+
if (expandedShape.size() == 1)
486+
continue;
479487
for (int64_t expandedDimShape : llvm::make_range(
480488
std::next(expandedShape.begin()), expandedShape.end())) {
481489
if (ShapedType::isDynamic(expandedDimShape)) {

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,18 @@ SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
104104
auto shape = v.getType().cast<ShapedType>().getShape();
105105
res.append(shape.begin(), shape.end());
106106
}
107+
if (linalgOp.getNumInitTensors())
108+
return res;
109+
for (Value v : linalgOp.getOperation()->getResults()) {
110+
auto shape = v.getType().cast<ShapedType>().getShape();
111+
res.append(shape.begin(), shape.end());
112+
}
107113
return res;
108114
}
109115

110116
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
111117
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
112-
AffineMap invertedMap =
113-
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
118+
AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
114119
if (!invertedMap)
115120
return {};
116121
return invertedMap.compose(viewSizes);

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,97 @@ func @reshape_as_consumer_permutation
344344
// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
345345
// CHECK: %[[T10:.+]] = index_cast %[[ARG7]]
346346
// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
347+
348+
// -----
349+
350+
func @reshape_as_producer_projected_permutation
351+
(%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> {
352+
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
353+
affine_map<(d0, d1, d2) -> (d2)>]
354+
: tensor<33x8x?xi32> into tensor<264x?xi32>
355+
%1 = linalg.indexed_generic
356+
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
357+
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
358+
iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) {
359+
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32): // no predecessors
360+
%2 = index_cast %arg1 : index to i32
361+
%3 = addi %arg4, %2 : i32
362+
%4 = index_cast %arg2 : index to i32
363+
%5 = addi %3, %4 : i32
364+
%6 = index_cast %arg3 : index to i32
365+
%7 = addi %5, %6 : i32
366+
linalg.yield %7 : i32
367+
} -> tensor<264x?x4xi32>
368+
return %1 : tensor<264x?x4xi32>
369+
}
370+
371+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
372+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
373+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
374+
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
375+
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
376+
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
377+
// CHECK: @reshape_as_producer_projected_permutation
378+
// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
379+
// CHECK: %[[RES:.+]] = linalg.indexed_generic
380+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
381+
// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
382+
// CHECK: ^{{.+}}(
383+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
384+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
385+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
386+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index,
387+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32)
388+
// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]])
389+
// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
390+
// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
391+
// CHECK: %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
392+
// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
393+
// CHECK: %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
394+
// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
395+
// CHECK: linalg.yield %[[T6]] : i32
396+
// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
397+
// CHECK-SAME: [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
398+
// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
399+
// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
400+
401+
// -----
402+
403+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
404+
#map1 = affine_map<(d0, d1) -> (d1, d0)>
405+
func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
406+
%arg1 : tensor<?x?xf32>) ->
407+
tensor<?x?x4x5xf32>
408+
{
409+
%0 = linalg.generic {
410+
indexing_maps = [#map0, #map0, #map1],
411+
iterator_types = ["parallel", "parallel"]}
412+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
413+
^bb0(%arg3: f32, %arg4: f32): // no predecessors
414+
%1 = mulf %arg3, %arg4 : f32
415+
linalg.yield %1 : f32
416+
} -> tensor<?x?xf32>
417+
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
418+
affine_map<(i, j, k, l) -> (j, k, l)>] :
419+
tensor<?x?xf32> into tensor<?x?x4x5xf32>
420+
return %1 : tensor<?x?x4x5xf32>
421+
}
422+
423+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
424+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
425+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
426+
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
427+
// CHECK: func @generic_op_reshape_consumer_fusion_projected
428+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
429+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
430+
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
431+
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
432+
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
433+
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
434+
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
435+
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
436+
// CHECK: %[[T2:.+]] = linalg.generic
437+
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
438+
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
439+
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
440+
// CHECK: return %[[T2]] : tensor<?x?x4x5xf32>

0 commit comments

Comments
 (0)