
Commit 11ac97c

[mlir][tosa] Move lowering of tosa.transpose to tosa-to-linalg-named (#75738)
Currently, there exists a pattern lowering `tosa.transpose` to `linalg.generic` in `tosa-to-linalg`. This patch removes that pattern and instead adds one lowering `tosa.transpose` to `linalg.transpose` in `tosa-to-linalg-named`. Lowering to the named linalg op has the advantage that subsequent optimization passes can identify transpositions directly, without having to pattern-match `linalg.generic` ops; when the generic form is needed, the `linalg.transpose` can simply be generalized to a `linalg.generic` in a second step.
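For illustration, a minimal before/after sketch of the new lowering, mirroring the @test_transpose case added to the tests below (the value names %in, %perms, and %init are illustrative, not from the commit):

// Before: tosa.transpose takes its permutation as a constant tensor operand.
%perms = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
%out = tosa.transpose %in, %perms : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>

// After tosa-to-linalg-named: an empty init tensor with the permuted shape,
// filled by the named transpose op.
%init = tensor.empty() : tensor<2x3x1xi32>
%out = linalg.transpose ins(%in : tensor<1x2x3xi32>) outs(%init : tensor<2x3x1xi32>) permutation = [1, 2, 0]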
1 parent 3af59cf commit 11ac97c

File tree (6 files changed, +99 -154 lines changed):
- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
- mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
- mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 1 addition & 51 deletions
@@ -1052,55 +1052,6 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
-class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
-public:
-  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const final {
-    DenseIntElementsAttr perms;
-    if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
-      return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
-    }
-
-    auto loc = op.getLoc();
-    auto input = op->getOperand(0);
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    SmallVector<Value> dynDims;
-    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
-
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultTy.getRank());
-    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
-      auto index = permutation.index();
-      auto value = permutation.value().getZExtValue();
-      if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
-        dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
-      }
-      inputExprs[value] = rewriter.getAffineDimExpr(index);
-    }
-
-    SmallVector<Value> filteredDims = condenseValues(dynDims);
-
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);
-
-    SmallVector<AffineMap, 2> affineMaps = {
-        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
-                       rewriter.getContext()),
-        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
-
-    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
-        getNParallelLoopsAttrs(resultTy.getRank()),
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
-        });
-    return success();
-  }
-};
-
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -2454,7 +2405,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ReverseConverter,
       RFFT2dConverter,
       TableConverter,
-      TileConverter,
-      TransposeConverter>(patterns->getContext());
+      TileConverter>(patterns->getContext());
   // clang-format on
 }

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 29 additions & 1 deletion
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -984,6 +985,31 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
   }
 };
 
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    SmallVector<int64_t> constantPerms;
+    if (failed(op.getConstantPerms(constantPerms)))
+      return failure();
+
+    Location loc = op.getLoc();
+    // The verifier should have made sure we have a valid permutation tensor.
+    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+    SmallVector<OpFoldResult> inputSizes =
+        tensor::getMixedSizes(rewriter, loc, op.getInput1());
+    auto permutedSizes =
+        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+
+    auto permutedInit = rewriter.create<tensor::EmptyOp>(
+        loc, permutedSizes, op.getInput1().getType().getElementType());
+    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
+        op, op.getInput1(), permutedInit, constantPerms);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1004,6 +1030,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MatMulConverter,
       MaxPool2dConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter>(patterns->getContext());
+      FullyConnectedConverter,
+      TransposeConverter
+      >(patterns->getContext());
   // clang-format on
 }
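As a follow-up to the description's last point, the named op produced here can later be rewritten into generic form (e.g. via mlir-opt's -linalg-generalize-named-ops). A sketch of what that would yield for the [1, 2, 0] permutation above, with the same illustrative value names; the indexing maps match those the removed tosa-to-linalg pattern built directly (compare the deleted tests in tosa-to-linalg.mlir below):

// Input map reads the source at the inverse-permuted indices; output map is identity.
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
%out = linalg.generic {indexing_maps = [#transposed, #identity],
                       iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%in : tensor<1x2x3xi32>) outs(%init : tensor<2x3x1xi32>) {
^bb0(%in_elem: i32, %init_elem: i32):
  // Pure data movement: yield the input element unchanged.
  linalg.yield %in_elem : i32
} -> tensor<2x3x1xi32>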

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
+    target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 68 additions & 7 deletions
@@ -88,7 +88,8 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -110,7 +111,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -136,7 +137,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -377,7 +378,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
+  // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK: arith.extsi
@@ -398,7 +399,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
+  // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
 
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -677,7 +678,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
   // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -701,7 +702,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
   // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
   // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -720,3 +721,63 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
+func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+  %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
+func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
+  %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_2d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
+func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
+  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+  %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_3d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
+func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
+  %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+  // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return
+}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 0 additions & 10 deletions
@@ -38,13 +38,3 @@ func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.u
   %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
-
-// -----
-
-// check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate
-// this would have failed tosa strict-op-spec-alignment because perms of transpose is not constant
-// but tosa.transpose is lowered by tosa-to-linalg-named pass which is earlier than tosa-validate pass in the pipeline
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
-  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
-  return %0 : tensor<3x13x21xf32>
-}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 0 additions & 85 deletions
@@ -820,7 +820,6 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   return
 }
 
-
 // -----
 
 // CHECK-LABEL: @test_identity
@@ -836,90 +835,6 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_transpose
-// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
-func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
-  %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x1xi32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
-  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
-  // CHECK: linalg.yield [[ARG1]]
-  // CHECK: }
-  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
-  return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-// CHECK-LABEL: @test_transpose_dyn
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
-func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
-  %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
-  // CHECK: %[[C1:.+]] = arith.constant 1
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs([[OUT:%.+]] : tensor<?x4x1x3xi32>)
-  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
-  // CHECK: linalg.yield [[ARG1]]
-  // CHECK: }
-  %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
-  return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_2d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
-  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
-  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
-  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
-  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs([[OUT:%.+]] : tensor<?x?xf32>)
-  // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
-  // CHECK: linalg.yield [[ARG1]]
-  // CHECK: }
-  %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
-  return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_3d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
-func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
-  %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-  // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
-  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) {
-  // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[OUT0:.*]]: f32):
-  // CHECK: linalg.yield %[[IN0]] : f32
-  // CHECK: } -> tensor<?x?x?xf32>
-  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
