Skip to content

Commit e3e20bf

Browse files
committed
feat: constant fold tosa.tile when input is splat.
1 parent 57dd987 commit e3e20bf

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,14 @@ DenseElementsAttr tile(DenseElementsAttr inputValues, ShapedType outputType) {
17361736
auto inputType = inputValues.getType();
17371737
auto baseType = inputType.getElementType();
17381738

1739+
if (inputValues.isSplat()) {
1740+
if (isa<IntegerType>(baseType))
1741+
return DenseElementsAttr::get(outputType,
1742+
inputValues.getSplatValue<APInt>());
1743+
return DenseElementsAttr::get(outputType,
1744+
inputValues.getSplatValue<APFloat>());
1745+
}
1746+
17391747
// Handle possible integer types
17401748
if (auto intType = dyn_cast<IntegerType>(baseType)) {
17411749
switch (intType.getWidth()) {

mlir/test/Dialect/Tosa/constant-tile.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,40 @@ func.func @tile_f16_many_dimensions() -> (tensor<6x2x2xf16>) {
106106
%1 = tosa.tile %0 {multiples = array<i64: 3, 2, 1>} : (tensor<3x1x1xf16>) -> tensor<6x2x2xf16>
107107
// NO-FOLDING-CHECK: tosa.tile
108108
return %1 : tensor<6x2x2xf16>
109+
}
110+
111+
// CHECK-LABEL: @tile_i1_splat
112+
func.func @tile_i1_splat() -> (tensor<1x2x2x2xi1>) {
113+
// CHECK: "tosa.const"() <{value = dense<false> : tensor<1x2x2x2xi1>}>
114+
%0 = "tosa.const"() <{value = dense<false> : tensor<1x1x1x1xi1>}> : () -> tensor<1x1x1x1xi1>
115+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xi1>) -> tensor<1x2x2x2xi1>
116+
// NO-FOLDING-CHECK: tosa.tile
117+
return %1 : tensor<1x2x2x2xi1>
118+
}
119+
120+
// CHECK-LABEL: @tile_i32_splat
121+
func.func @tile_i32_splat() -> (tensor<1x2x2x2xi32>) {
122+
// CHECK: "tosa.const"() <{value = dense<2> : tensor<1x2x2x2xi32>}>
123+
%0 = "tosa.const"() <{value = dense<2> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
124+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xi32>) -> tensor<1x2x2x2xi32>
125+
// NO-FOLDING-CHECK: tosa.tile
126+
return %1 : tensor<1x2x2x2xi32>
127+
}
128+
129+
// CHECK-LABEL: @tile_f16_splat
130+
func.func @tile_f16_splat() -> (tensor<1x2x2x2xf16>) {
131+
// CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xf16>}>
132+
%0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16>
133+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xf16>) -> tensor<1x2x2x2xf16>
134+
// NO-FOLDING-CHECK: tosa.tile
135+
return %1 : tensor<1x2x2x2xf16>
136+
}
137+
138+
// CHECK-LABEL: @tile_bf16_splat
139+
func.func @tile_bf16_splat() -> (tensor<1x2x2x2xbf16>) {
140+
// CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xbf16>}>
141+
%0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16>
142+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xbf16>) -> tensor<1x2x2x2xbf16>
143+
// NO-FOLDING-CHECK: tosa.tile
144+
return %1 : tensor<1x2x2x2xbf16>
109145
}

0 commit comments

Comments
 (0)