@@ -40,22 +40,43 @@ using llvm::SetVector;
40
40
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
41
41
" __internal_linalg_transform__" ;
42
42
43
- LogicalResult mlir::linalg::tileLinalgOpAndSetMarker (
44
- PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
45
- StringRef linalgMarker, ArrayRef<unsigned > permutation) {
43
+ using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t >,
44
+ ArrayRef<unsigned >, OperationFolder *);
45
+
46
+ static LogicalResult
47
+ tileLinalgOpAndSetMarkerImpl (TileFn tileFn, PatternRewriter &rewriter,
48
+ Operation *op, ArrayRef<int64_t > sizes,
49
+ StringRef linalgMarker,
50
+ ArrayRef<unsigned > permutation) {
46
51
assert (permutation.empty () || permutation.size () == sizes.size ());
47
- auto tileRes = tileLinalgOperation (rewriter, op, sizes, permutation);
52
+ auto tileRes = tileFn (rewriter, op, sizes, permutation, /* folder= */ nullptr );
48
53
if (!tileRes)
49
54
return failure ();
50
55
tileRes->op .setAttr (LinalgTransforms::kLinalgTransformMarker ,
51
56
rewriter.getStringAttr (linalgMarker));
52
57
return success ();
53
58
}
54
59
55
- LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker (
60
+ LogicalResult mlir::linalg::tileLinalgOpAndSetMarker (
56
61
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
57
- ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
58
- auto tileRes = tileLinalgOperation (rewriter, op, sizes);
62
+ StringRef linalgMarker, ArrayRef<unsigned > permutation) {
63
+ return tileLinalgOpAndSetMarkerImpl (tileLinalgOp, rewriter, op, sizes,
64
+ linalgMarker, permutation);
65
+ }
66
+ LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker (
67
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
68
+ StringRef linalgMarker, ArrayRef<unsigned > permutation) {
69
+ return tileLinalgOpAndSetMarkerImpl (tileLinalgOpToParallelLoops, rewriter, op,
70
+ sizes, linalgMarker, permutation);
71
+ }
72
+
73
+ static LogicalResult
74
+ tileAndFuseLinalgOpAndSetMarkerImpl (TileFn tileFn, PatternRewriter &rewriter,
75
+ Operation *op, ArrayRef<int64_t > sizes,
76
+ ArrayRef<int64_t > operandIndicesToFuse,
77
+ StringRef linalgMarker) {
78
+ auto tileRes =
79
+ tileFn (rewriter, op, sizes, /* permutation=*/ {}, /* folder=*/ nullptr );
59
80
if (!tileRes)
60
81
return failure ();
61
82
tileRes->op .setAttr (LinalgTransforms::kLinalgTransformMarker ,
@@ -89,6 +110,20 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
89
110
return success ();
90
111
}
91
112
113
+ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker (
114
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
115
+ ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
116
+ return tileAndFuseLinalgOpAndSetMarkerImpl (
117
+ tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
118
+ }
119
+ LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker (
120
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
121
+ ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
122
+ return tileAndFuseLinalgOpAndSetMarkerImpl (
123
+ tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
124
+ linalgMarker);
125
+ }
126
+
92
127
bool mlir::linalg::detail::isProducedByOpOfTypeImpl (
93
128
Operation *consumerOp, Value consumedView,
94
129
function_ref<bool (Operation *)> isaOpType) {
0 commit comments