Skip to content

Commit 89d5551

Browse files
committed
[mlir][tosa] Add constant folding for tosa.slice
If the input to a tosa.slice operation is a splat we can just replace with another splat. If the result is a single element, replacing with a splat is universally useful. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D132499
1 parent ecde303 commit 89d5551

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,30 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
555555
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
556556
auto outputTy = getType().dyn_cast<RankedTensorType>();
557557

558-
if (!inputTy || !outputTy || inputTy != outputTy)
558+
if (!inputTy || !outputTy)
559559
return {};
560-
if (inputTy.hasStaticShape())
560+
561+
if (inputTy == outputTy && inputTy.hasStaticShape())
561562
return getInput();
562563

564+
if (!operands[0])
565+
return {};
566+
567+
auto operand = operands[0].cast<ElementsAttr>();
568+
if (operand.isSplat() && outputTy.hasStaticShape()) {
569+
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
570+
}
571+
572+
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
573+
outputTy.getNumElements() == 1) {
574+
llvm::SmallVector<uint64_t> indices;
575+
for (auto val : getStart()) {
576+
indices.push_back(val.cast<IntegerAttr>().getInt());
577+
}
578+
auto value = operand.getValues<Attribute>()[indices];
579+
return SplatElementsAttr::get(outputTy, value);
580+
}
581+
563582
return {};
564583
}
565584

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,25 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
161161
// CHECK: return %[[THREE]]
162162
return %add : tensor<10xf32>
163163
}
164+
165+
// -----
166+
167+
// CHECK-LABEL: @slice_splat
168+
func.func @slice_splat() -> tensor<1x1x1xi32> {
169+
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
170+
%splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
171+
%slice = "tosa.slice"(%splat) { size = [1, 1, 1], start = [1, 2, 3] } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32>
172+
// CHECK: return %[[SLICE]]
173+
return %slice : tensor<1x1x1xi32>
174+
}
175+
176+
// -----
177+
178+
// CHECK-LABEL: @slice_singleton
179+
func.func @slice_singleton() -> tensor<1x1xi32> {
180+
%splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
181+
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<4> : tensor<1x1xi32>}
182+
%slice = "tosa.slice"(%splat) { size = [1, 1], start = [1, 1] } : (tensor<3x3xi32>) -> tensor<1x1xi32>
183+
// CHECK: return %[[SLICE]]
184+
return %slice : tensor<1x1xi32>
185+
}

0 commit comments

Comments
 (0)