Skip to content

Commit 42cd9ae

Browse files
Fold linalg.fill -> linalg.copy (#72920)
1 parent 1f14173 commit 42cd9ae

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -803,14 +803,36 @@ struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
803803
}
804804
};
805805

806+
/// Fold fill with copy.
///
/// Rewrite pattern rooted at linalg::CopyOp. It tries two rewrites, in order:
///  1. If the copy's first input is defined by a FillOp, the copy is replaced
///     by a new FillOp that takes the fill's inputs and writes into the
///     copy's outputs.
///  2. Otherwise, if the copy's first output is defined by a FillOp, the copy
///     is rebuilt with the same inputs but with the fill's outputs as its
///     outputs, so the new copy no longer uses the fill's result.
/// Returns failure() when neither operand is defined by a FillOp.
807+
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
808+
using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
809+
810+
LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
811+
PatternRewriter &rewriter) const override {
812+
// Case 1: the copied value is a fill result, so fill the copy's
// destination directly with the fill's inputs.
// NOTE(review): presumably assumes the fill and copy element types agree,
// so dropping the intermediate tensor cannot change values - confirm.
if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
813+
rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
814+
fillOp.getInputs(),
815+
copyOp.getOutputs());
816+
return success();
817+
}
818+
// Case 2: the copy's destination is a fill result; re-target the copy at
// the fill's own destination.
// NOTE(review): presumably valid because the copy overwrites every element
// the fill wrote - confirm.
if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
819+
rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
820+
fillOp.getOutputs());
821+
return success();
822+
}
823+
// Neither the first input nor the first output comes from a FillOp.
return failure();
824+
}
825+
};
826+
806827
} // namespace
807828

808829
/// Adds the canonicalization patterns associated with FillOp to `results`;
/// each listed pattern is constructed with `context`.
/// Note: FoldFillWithCopy is rooted at linalg::CopyOp even though it is
/// registered through FillOp's pattern list.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
809830
MLIRContext *context) {
810-
results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
811-
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
812-
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
813-
FoldInsertPadIntoFill>(context);
831+
results
832+
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
833+
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
834+
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
835+
FoldInsertPadIntoFill>(context);
814836
}
815837

816838
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,3 +972,29 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<
972972
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
973973
return %3: tensor<?x?xf32>
974974
}
975+
// -----
976+
977+
// Test: a linalg.copy whose input is the result of a linalg.fill should
// canonicalize to a linalg.fill of the same constant written straight into
// the copy's destination (%arg1).
// CHECK-LABEL: func @canonicalize_fill_to_copy_input(
978+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
979+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
980+
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
981+
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
982+
func.func @canonicalize_fill_to_copy_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
983+
%c0 = arith.constant 0.0 : f32
984+
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
985+
%copy = linalg.copy ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
986+
return %copy : tensor<?x?xf32>
987+
}
988+
989+
// -----
990+
991+
// Test: a linalg.copy whose destination is the result of a linalg.fill should
// canonicalize to a linalg.copy that writes into the fill's own destination
// (%arg0), reading from the original source (%arg1).
// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
992+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
993+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
994+
// CHECK: linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
995+
func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
996+
%c0 = arith.constant 0.0 : f32
997+
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
998+
%copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
999+
return %copy : tensor<?x?xf32>
1000+
}

0 commit comments

Comments
 (0)