Skip to content

Commit 7060920

Browse files
committed
Relax FuseTensorReshapeOpAsproducer identity mapping constraint
Differential Revision: https://reviews.llvm.org/D88869
1 parent 5e4409f commit 7060920

File tree

2 files changed

+110
-7
lines changed

2 files changed

+110
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
326326
if ((asProducer && returnType.getRank() < operandType.getRank()) ||
327327
(!asProducer && operandType.getRank() < returnType.getRank()))
328328
return false;
329-
return useIndexMap.isIdentity();
329+
return useIndexMap.isPermutation();
330330
}
331331

332332
/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
@@ -381,10 +381,13 @@ struct FuseTensorReshapeOpAsProducer {
381381
return attr.cast<AffineMapAttr>().getValue();
382382
}));
383383

384+
// Accepted consumer maps are either identity or permutation.
385+
auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
386+
384387
// Compute the indexing map to use for the operand of the producer.
385-
AffineMap modifiedMap = linearizeCollapsedDims(
386-
fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
387-
producer.getReassociationMaps());
388+
AffineMap modifiedMap =
389+
linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
390+
producer.getReassociationMaps());
388391
for (AffineExpr expr : modifiedMap.getResults()) {
389392
if (!expr.isPureAffine())
390393
return nullptr;
@@ -439,10 +442,13 @@ struct FuseTensorReshapeOpAsConsumer {
439442
producer.indexing_maps(), [](Attribute attr) -> AffineMap {
440443
return attr.cast<AffineMapAttr>().getValue();
441444
}));
445+
446+
auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
447+
442448
// Compute the indexing map to use for the operand of the producer.
443-
AffineMap modifiedMap = linearizeCollapsedDims(
444-
producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
445-
consumer.getReassociationMaps());
449+
AffineMap modifiedMap =
450+
linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
451+
consumer.getReassociationMaps());
446452
for (AffineExpr expr : modifiedMap.getResults()) {
447453
if (!expr.isPureAffine())
448454
return nullptr;

mlir/test/Dialect/Linalg/fusion-tensor.mlir

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,100 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
558558
// CHECK: linalg.indexed_generic
559559
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
560560
// CHECK-NOT: linalg.tensor_reshape
561+
562+
// -----
563+
564+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
565+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
566+
567+
#map0 = affine_map<(d0, d1, d2) -> (d0)>
568+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
569+
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
570+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
571+
func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
572+
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
573+
%1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
574+
^bb0(%arg2: f32): // no predecessors
575+
linalg.yield %arg2 : f32
576+
} -> tensor<3x7x5xf32>
577+
return %1 : tensor<3x7x5xf32>
578+
}
579+
580+
// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
581+
// CHECK-NOT: linalg.tensor_reshape
582+
// CHECK: linalg.generic
583+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
584+
// CHECK-NOT: linalg.tensor_reshape
585+
586+
// -----
587+
588+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
589+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
590+
591+
#map0 = affine_map<(d0, d1, d2) -> (d0)>
592+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
593+
#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
594+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
595+
func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
596+
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
597+
%1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
598+
^bb0(%arg2: f32): // no predecessors
599+
linalg.yield %arg2 : f32
600+
} -> tensor<5x7x3xf32>
601+
return %1 : tensor<5x7x3xf32>
602+
}
603+
604+
// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
605+
// CHECK-NOT: linalg.tensor_reshape
606+
// CHECK: linalg.generic
607+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
608+
// CHECK-NOT: linalg.tensor_reshape
609+
610+
// -----
611+
612+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
613+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
614+
615+
#map0 = affine_map<(d0, d1, d2) -> (d0)>
616+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
617+
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
618+
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
619+
func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
620+
%0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
621+
%1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
622+
^bb0(%arg2: f32): // no predecessors
623+
linalg.yield %arg2 : f32
624+
} -> tensor<5x3x7xf32>
625+
return %1 : tensor<5x3x7xf32>
626+
}
627+
628+
// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
629+
// CHECK-NOT: linalg.tensor_reshape
630+
// CHECK: linalg.generic
631+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
632+
// CHECK-NOT: linalg.tensor_reshape
633+
634+
// -----
635+
636+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
637+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
638+
639+
640+
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
641+
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
642+
#map2 = affine_map<(d0, d1, d2) -> (d0)>
643+
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
644+
func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
645+
%0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
646+
^bb0(%arg2: f32): // no predecessors
647+
linalg.yield %arg2 : f32
648+
} -> tensor<5x3x7xf32>
649+
%1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
650+
return %1 : tensor<5x21xf32>
651+
}
652+
653+
// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
654+
// CHECK-NOT: linalg.tensor_reshape
655+
// CHECK: linalg.generic
656+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
657+
// CHECK-NOT: linalg.tensor_reshape

0 commit comments

Comments
 (0)