
Commit 01055ed

[mlir][linalg] Move linalg.fill folding into linalg.generic pattern from canonicalization to elementwise fusion
Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D122847
1 parent 857d699 commit 01055ed

4 files changed, +79 -79 lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 24 deletions
@@ -913,35 +913,12 @@ struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-
-/// Fold linalg.fill into linalg.generic
-struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-    bool fillFound = false;
-    Block &payload = genericOp.region().front();
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
-      if (fillOp) {
-        fillFound = true;
-        payload.getArgument(opOperand->getOperandNumber())
-            .replaceAllUsesWith(fillOp.value());
-      }
-    }
-    // fail if there are no FillOps to fold.
-    return success(fillFound);
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
+              DeadArgsGenericOpInputs>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
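After this hunk, collecting the GenericOp canonicalization patterns no longer picks up the fill fold, so a plain canonicalization run leaves linalg.fill producers of linalg.generic inputs intact. A minimal sketch of a driver that exercises only these canonicalization patterns, assuming the standard greedy rewrite driver; the helper name runGenericCanonicalization and the exact include paths are illustrative, not part of this commit:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: apply only the linalg.generic canonicalization
// patterns. After this commit the fill fold is no longer among them, so a
// linalg.fill feeding a linalg.generic input is left untouched here.
static mlir::LogicalResult runGenericCanonicalization(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::linalg::GenericOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}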

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 25 additions & 2 deletions
@@ -2215,8 +2215,31 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-} // namespace
 
+/// Fold linalg.fill into linalg.generic
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    bool fillFound = false;
+    Block &payload = genericOp.region().front();
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      if (!genericOp.payloadUsesValueFromOperand(opOperand))
+        continue;
+      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+      if (!fillOp)
+        continue;
+      fillFound = true;
+      payload.getArgument(opOperand->getOperandNumber())
+          .replaceAllUsesWith(fillOp.value());
+    }
+    return success(fillFound);
+  }
+};
+} // namespace
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
@@ -2261,7 +2284,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
                FoldConstantTranspose>(context,
                                       options.controlElementwiseOpsFusionFn);
-  patterns.add<RemoveOutsDependency>(context);
+  patterns.add<RemoveOutsDependency, FoldFillWithGenericOp>(context);
   populateSparseTensorRewriting(patterns);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
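With the pattern registered here, the fold is applied by any client that populates the elementwise fusion patterns (for example the pass that drives mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir below). A minimal sketch of such a client, assuming the populateElementwiseOpsFusionPatterns entry point shown in the hunk above and an options struct named LinalgElementwiseFusionOptions; the helper runElementwiseFusion is illustrative, not part of this commit:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: populating the elementwise fusion patterns now also
// pulls in FoldFillWithGenericOp, so linalg.fill producers of linalg.generic
// inputs are folded together with the other fusion rewrites.
static mlir::LogicalResult runElementwiseFusion(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Options type name assumed from Transforms.h at this revision; it carries
  // the controlElementwiseOpsFusionFn / controlFoldingReshapesFn callbacks
  // referenced in the hunk above.
  mlir::linalg::LinalgElementwiseFusionOptions options;
  mlir::linalg::populateElementwiseOpsFusionPatterns(patterns, options);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}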

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 0 additions & 53 deletions
@@ -343,59 +343,6 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @fold_fill_generic_basic
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-  %1 = linalg.init_tensor [%0] : tensor<?xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
-  %3 = linalg.init_tensor [%0] : tensor<?xf32>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %5 = arith.addf %arg1, %arg2 : f32
-    linalg.yield %5 : f32
-  } -> tensor<?xf32>
-  return %4 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @fold_fill_generic_mixedaccess
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-NOT: ins
-// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
-  %cst1 = arith.constant 7.0 : f32
-  %cst2 = arith.constant 6.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
-  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %8 = arith.divf %arg1, %arg2 : f32
-    linalg.yield %8 : f32
-  } -> tensor<?x?xf32>
-  return %7 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @remove_deadargs_generic_basic
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 53 additions & 0 deletions
@@ -975,3 +975,56 @@ func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tens
 // CHECK: %[[PRODUCER:.+]] = linalg.generic
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[PRODUCER]]
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+  %3 = linalg.init_tensor [%0] : tensor<?xf32>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %5 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 0 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %8 = arith.divf %arg1, %arg2 : f32
+    linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %7 : tensor<?x?xf32>
+}
