Skip to content

Commit 0cf8447

Browse files
authored
[MLIR][SCF] Fold dim ops of iter_args to respective init_args (#109973)
Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
```
%0 = ... : tensor<?x?xf32>
scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
  %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  ...
}
```
is folded to:
```
%0 = ... : tensor<?x?xf32>
scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  ...
}
```
1 parent 8ea2b41 commit 0cf8447

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
103103
return success();
104104
}
105105
};
106+
107+
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
108+
///
109+
/// ```
110+
/// %0 = ... : tensor<?x?xf32>
111+
/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
112+
/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
113+
/// ...
114+
/// }
115+
/// ```
116+
///
117+
/// is folded to:
118+
///
119+
/// ```
120+
/// %0 = ... : tensor<?x?xf32>
121+
/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
122+
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
123+
/// ...
124+
/// }
125+
/// ```
126+
struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  /// Redirects the source of `dimOp` from a loop iter_arg block argument to
  /// the init operand tied to that iter_arg.
  ///
  /// Returns failure (leaving the IR untouched) when the source is not a
  /// block argument, when the block's parent op does not implement
  /// LoopLikeOpInterface, or when the block argument has no tied init
  /// operand. Returns success after updating `dimOp` in place.
  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const final {
    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
    if (!blockArg)
      return failure();

    auto loopLikeOp =
        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
    if (!loopLikeOp)
      return failure();

    // Not every block argument of a loop-like op is a region iter_arg (e.g.
    // induction variables, or arguments of a region other than the one that
    // carries the iter_args). getTiedLoopInit returns null for those, so the
    // result must be checked before it is dereferenced.
    auto *tiedInit = loopLikeOp.getTiedLoopInit(blockArg);
    if (!tiedInit)
      return failure();
    Value initArg = tiedInit->get();

    // NOTE(review): this fold is only sound when the loop preserves the
    // shape of the iter_arg across iterations. That holds for the
    // scf.forall shared_outs case this pattern is tested with; a loop that
    // yields a tensor with different dynamic sizes than its init would be
    // miscompiled. Consider restricting the match to shape-preserving loops.

    // Update through the rewriter so the driver is notified of the change.
    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
    return success();
  }
};
106144
} // namespace
107145

108146
//===----------------------------------------------------------------------===//
@@ -127,8 +165,8 @@ struct ResolveShapedTypeResultDimsPass final
127165
/// Adds the dim-resolution rewrite patterns to `patterns`, constructed with
/// the pattern set's context. As of this change the set consists of
/// DimOfReifyRankedShapedTypeOpInterface instantiated for memref::DimOp and
/// for tensor::DimOp, plus the newly added IterArgsToInitArgs pattern.
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
128166
RewritePatternSet &patterns) {
129167
patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130-
DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
131-
patterns.getContext());
168+
DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
169+
IterArgsToInitArgs>(patterns.getContext());
132170
}
133171

134172
void memref::populateResolveShapedTypeResultDimsPatterns(

mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,31 @@ func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
7171
%1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
7272
return %1 : index
7373
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: @iter_to_init_arg_loop_like
78+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
79+
// CHECK: %[[RESULT:.*]] = scf.forall
80+
// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
81+
// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
82+
// Test input: a tensor.dim taken on the scf.forall shared_outs block
// argument %o, whose init operand is %arg1. The CHECK lines above expect the
// tensor.dim inside the loop to read from the function argument bound to
// %arg1 instead of from the block argument.
func.func @iter_to_init_arg_loop_like(
83+
%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
84+
%c0 = arith.constant 0 : index
85+
%c1 = arith.constant 1 : index
86+
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
87+
88+
%result = scf.forall (%i) = (%c0) to (%dim0)
89+
step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
90+
91+
// This is the dim op under test: its source is the block argument %o.
%dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
92+
%slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
93+
: tensor<?x?xf32> to tensor<1x?xf32>
94+
95+
scf.forall.in_parallel {
96+
tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
97+
: tensor<1x?xf32> into tensor<?x?xf32>
98+
}
99+
}
100+
return %result : tensor<?x?xf32>
101+
}

0 commit comments

Comments
 (0)