Skip to content

Commit c077a4f

Browse files
[mlir][Tensor] Add pattern to fold concats of empty. (#98994)
A concatenation of empty tensors can be replaced by a single empty tensor of the concatenated shape. Add this pattern to `populateFoldTensorEmptyPatterns`.
1 parent 2d42f84 commit c077a4f

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,38 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
136136
}
137137
};
138138

139+
// Fold concat operation where all the operands are empty.
140+
struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
141+
using OpRewritePattern<ConcatOp>::OpRewritePattern;
142+
143+
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
144+
PatternRewriter &rewriter) const override {
145+
auto concatOperands = concatOp.getInputs();
146+
if (concatOperands.empty()) {
147+
return failure();
148+
}
149+
auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
150+
if (!firstEmptyOp) {
151+
return failure();
152+
}
153+
auto isDefinedByEmptyOp = [](Value v) -> bool {
154+
return v.getDefiningOp<tensor::EmptyOp>();
155+
};
156+
if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
157+
return rewriter.notifyMatchFailure(
158+
concatOp, "not all operands are defined by an empty op");
159+
}
160+
SmallVector<SmallVector<OpFoldResult>> resultShape;
161+
if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
162+
return rewriter.notifyMatchFailure(concatOp,
163+
"failed to get result shape");
164+
}
165+
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
166+
concatOp, resultShape[0], concatOp.getResultType().getElementType());
167+
return success();
168+
}
169+
};
170+
139171
} // namespace
140172

141173
void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -144,6 +176,7 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
144176
FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
145177
FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
146178
patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
147-
patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
148-
patterns.getContext(), /*benefit=*/1);
179+
patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
180+
FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
181+
/*benefit=*/1);
149182
}

mlir/test/Dialect/Tensor/fold-empty-op.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,41 @@ func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
164164
// CHECK: tensor.empty{{.*}} : tensor<?x10x40xf32>
165165
// CHECK: tensor.extract_slice
166166
// CHECK: tensor.extract_slice
167+
168+
// -----
169+
170+
module attributes {transform.with_named_sequence} {
171+
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
172+
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
173+
transform.apply_patterns to %func_op {
174+
transform.apply_patterns.tensor.fold_tensor_empty
175+
} : !transform.op<"func.func">
176+
transform.yield
177+
}
178+
}
179+
180+
func.func @concats_of_empty(
181+
%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
182+
-> tensor<5x?x?xf32>
183+
{
184+
%0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
185+
%1 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
186+
%2 = tensor.concat dim(1) %0, %1 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
187+
return %2 : tensor<5x?x?xf32>
188+
}
189+
// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
190+
// CHECK: func @concats_of_empty(
191+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
192+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
193+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
194+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index)
195+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
196+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
197+
// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
198+
// CHECK-DAG: %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
199+
// CHECK: %[[D2:.+]] = tensor.dim %[[EMPTY0]], %[[C2]]
200+
// CHECK-DAG: %[[D0_1:.+]] = tensor.dim %[[EMPTY0]], %[[C1]]
201+
// CHECK-DAG: %[[D1_1:.+]] = tensor.dim %[[EMPTY1]], %[[C1]]
202+
// CHECK-DAG: %[[SUM:.+]] = affine.apply #[[MAP]]()[%[[D0_1]], %[[D1_1]]]
203+
// CHECK: %[[NEW_EMPTY:.+]] = tensor.empty(%[[SUM]], %[[D2]])
204+
// CHECK: return %[[NEW_EMPTY]]

0 commit comments

Comments
 (0)