@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
103
103
return success ();
104
104
}
105
105
};
106
+
107
+ // / Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
108
+ // /
109
+ // / ```
110
+ // / %0 = ... : tensor<?x?xf32>
111
+ // / scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
112
+ // / %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
113
+ // / ...
114
+ // / }
115
+ // / ```
116
+ // /
117
+ // / is folded to:
118
+ // /
119
+ // / ```
120
+ // / %0 = ... : tensor<?x?xf32>
121
+ // / scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
122
+ // / %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
123
+ // / ...
124
+ // / }
125
+ // / ```
126
+ struct IterArgsToInitArgs : public OpRewritePattern <tensor::DimOp> {
127
+ using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
128
+
129
+ LogicalResult matchAndRewrite (tensor::DimOp dimOp,
130
+ PatternRewriter &rewriter) const final {
131
+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource ());
132
+ if (!blockArg)
133
+ return failure ();
134
+ auto loopLikeOp =
135
+ dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock ()->getParentOp ());
136
+ if (!loopLikeOp)
137
+ return failure ();
138
+ Value initArg = loopLikeOp.getTiedLoopInit (blockArg)->get ();
139
+ rewriter.modifyOpInPlace (
140
+ dimOp, [&]() { dimOp.getSourceMutable ().assign (initArg); });
141
+ return success ();
142
+ }
143
+ };
106
144
} // namespace
107
145
108
146
// ===----------------------------------------------------------------------===//
@@ -127,8 +165,8 @@ struct ResolveShapedTypeResultDimsPass final
127
165
void memref::populateResolveRankedShapedTypeResultDimsPatterns (
128
166
RewritePatternSet &patterns) {
129
167
patterns.add <DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130
- DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
131
- patterns.getContext ());
168
+ DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
169
+ IterArgsToInitArgs>( patterns.getContext ());
132
170
}
133
171
134
172
void memref::populateResolveShapedTypeResultDimsPatterns (
0 commit comments