Skip to content

Commit 5cd074f

Browse files
authored
[mlir] Add ReifyRankedShapedTypeOpInterface to tosa::TransposeOp (#88890)
1 parent bb95f5d commit 5cd074f

File tree

3 files changed

+63
-8
lines changed

3 files changed

+63
-8
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,7 +1502,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
15021502

15031503
let hasFolder = 1;
15041504
let hasVerifier = 1;
1505-
1505+
15061506
let extraClassDeclaration = [{
15071507
/// Returns true when two result types are compatible for this op;
15081508
/// Method used by InferTypeOpInterface.
@@ -1652,7 +1652,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
16521652

16531653
let hasFolder = 1;
16541654
let hasVerifier = 1;
1655-
1655+
16561656
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
16571657
}
16581658

@@ -1708,7 +1708,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
17081708
//===----------------------------------------------------------------------===//
17091709
// Operator: transpose
17101710
//===----------------------------------------------------------------------===//
1711-
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
1711+
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
1712+
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17121713
let summary = "Transpose operator";
17131714

17141715
let description = [{
@@ -1835,9 +1836,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
18351836

18361837
| Mode | Input | Output |
18371838
|--------------------------|---------|---------|
1838-
| signed 8 to bool | int8 | Boolean |
1839-
| signed 16 to bool | int16 | Boolean |
1840-
| signed 32 to bool | int32 | Boolean |
1839+
| signed 8 to bool | int8 | Boolean |
1840+
| signed 16 to bool | int16 | Boolean |
1841+
| signed 32 to bool | int32 | Boolean |
18411842
| bool to 8 | Boolean | int8 |
18421843
| bool to 16 | Boolean | int16 |
18431844
| bool to 32 | Boolean | int32 |
@@ -1851,8 +1852,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
18511852
| float to signed 16 | float | int16 |
18521853
| signed 8 to float | int8 | float |
18531854
| signed 16 to float | int16 | float |
1854-
| float 32 to float 64 | float32 | float64 |
1855-
| float 64 to float 32 | float64 | float32 |
1855+
| float 32 to float 64 | float32 | float64 |
1856+
| float 64 to float 32 | float64 | float32 |
18561857
}];
18571858

18581859
let arguments = (ins

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,32 @@ LogicalResult tosa::TransposeOp::verify() {
11191119
return success();
11201120
}
11211121

1122+
LogicalResult TransposeOp::reifyResultShapes(
1123+
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1124+
1125+
SmallVector<int64_t> transposePerms;
1126+
if (getConstantPerms(transposePerms).failed())
1127+
return failure();
1128+
1129+
Value input = getInput1();
1130+
auto inputType = input.getType().cast<TensorType>();
1131+
1132+
SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1133+
for (auto dim : transposePerms) {
1134+
int64_t dimInInput = transposePerms[dim];
1135+
if (inputType.isDynamicDim(dimInInput))
1136+
returnedDims[dim] =
1137+
builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1138+
.getResult();
1139+
else
1140+
returnedDims[dim] =
1141+
builder.getIndexAttr(inputType.getDimSize(dimInInput));
1142+
}
1143+
1144+
reifiedReturnShapes.emplace_back(std::move(returnedDims));
1145+
return success();
1146+
}
1147+
11221148
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
11231149
MLIRContext *context, ::std::optional<Location> location,
11241150
GatherOp::Adaptor adaptor,

mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,31 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
2525
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
2626
return %0 : index
2727
}
28+
29+
// -----
30+
31+
// CHECK-LABEL: func.func @dynamic_dim_of_transpose_op(
32+
// CHECK-SAME: %[[arg:.*]]: tensor<1x2x?x8xi8>) -> index {
33+
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
34+
// CHECK-NEXT: tensor.dim %[[arg]], %[[c2]]
35+
// CHECK-NEXT: return
36+
func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
37+
%0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
38+
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
39+
%c3 = arith.constant 3 : index
40+
%dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
41+
return %dim : index
42+
}
43+
44+
// -----
45+
46+
// CHECK-LABEL: func.func @static_dim_of_transpose_op(
47+
// CHECK: arith.constant 100 : index
48+
// CHECK: return
49+
func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
50+
%0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
51+
%1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
52+
%c2 = arith.constant 2 : index
53+
%dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
54+
return %dim : index
55+
}

0 commit comments

Comments
 (0)