Skip to content

Commit 3d27455

Browse files
authored
Merge pull request #187 from Xilinx/tiagot.extend_tosa_tile_folding_splat
feat: constant fold tosa.tile when input is splat.
2 parents 57dd987 + 98430c8 commit 3d27455

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1736,11 +1736,19 @@ DenseElementsAttr tile(DenseElementsAttr inputValues, ShapedType outputType) {
17361736
auto inputType = inputValues.getType();
17371737
auto baseType = inputType.getElementType();
17381738

1739+
if (inputValues.isSplat()) {
1740+
if (isa<IntegerType>(baseType))
1741+
return DenseElementsAttr::get(outputType,
1742+
inputValues.getSplatValue<APInt>());
1743+
return DenseElementsAttr::get(outputType,
1744+
inputValues.getSplatValue<APFloat>());
1745+
}
1746+
17391747
// Handle possible integer types
17401748
if (auto intType = dyn_cast<IntegerType>(baseType)) {
17411749
switch (intType.getWidth()) {
17421750
case 1:
1743-
// i1 has special alignment which is not handled by transposeTypeRaw.
1751+
// i1 has special alignment which is not handled by tileTypeRaw.
17441752
return tileType<bool>(inputValues, inputType, outputType);
17451753
case 8:
17461754
return tileTypeRaw<uint8_t>(inputValues, inputType, outputType);

mlir/test/Dialect/Tosa/constant-tile.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,40 @@ func.func @tile_f16_many_dimensions() -> (tensor<6x2x2xf16>) {
106106
%1 = tosa.tile %0 {multiples = array<i64: 3, 2, 1>} : (tensor<3x1x1xf16>) -> tensor<6x2x2xf16>
107107
// NO-FOLDING-CHECK: tosa.tile
108108
return %1 : tensor<6x2x2xf16>
109+
}
110+
111+
// CHECK-LABEL: @tile_i1_splat
112+
func.func @tile_i1_splat() -> (tensor<1x2x2x2xi1>) {
113+
// CHECK: "tosa.const"() <{value = dense<false> : tensor<1x2x2x2xi1>}>
114+
%0 = "tosa.const"() <{value = dense<false> : tensor<1x1x1x1xi1>}> : () -> tensor<1x1x1x1xi1>
115+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xi1>) -> tensor<1x2x2x2xi1>
116+
// NO-FOLDING-CHECK: tosa.tile
117+
return %1 : tensor<1x2x2x2xi1>
118+
}
119+
120+
// CHECK-LABEL: @tile_i32_splat
121+
func.func @tile_i32_splat() -> (tensor<1x2x2x2xi32>) {
122+
// CHECK: "tosa.const"() <{value = dense<2> : tensor<1x2x2x2xi32>}>
123+
%0 = "tosa.const"() <{value = dense<2> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
124+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xi32>) -> tensor<1x2x2x2xi32>
125+
// NO-FOLDING-CHECK: tosa.tile
126+
return %1 : tensor<1x2x2x2xi32>
127+
}
128+
129+
// CHECK-LABEL: @tile_f16_splat
130+
func.func @tile_f16_splat() -> (tensor<1x2x2x2xf16>) {
131+
// CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xf16>}>
132+
%0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16>
133+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xf16>) -> tensor<1x2x2x2xf16>
134+
// NO-FOLDING-CHECK: tosa.tile
135+
return %1 : tensor<1x2x2x2xf16>
136+
}
137+
138+
// CHECK-LABEL: @tile_bf16_splat
139+
func.func @tile_bf16_splat() -> (tensor<1x2x2x2xbf16>) {
140+
// CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xbf16>}>
141+
%0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16>
142+
%1 = tosa.tile %0 {multiples = array<i64: 1, 2, 2, 2>} : (tensor<1x1x1x1xbf16>) -> tensor<1x2x2x2xbf16>
143+
// NO-FOLDING-CHECK: tosa.tile
144+
return %1 : tensor<1x2x2x2xbf16>
109145
}

0 commit comments

Comments
 (0)