Skip to content

Commit 087bc20

Browse files
committed
[MLIR][TOSA] Lower tosa.transpose to linalg.generic
Lowers the transpose operation to a generic linalg op when permutations is a constant value. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D97508
1 parent 2fcc3f4 commit 087bc20

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Math/IR/Math.h"
1616
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1717
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
18+
#include "mlir/IR/Matchers.h"
1819
#include "mlir/IR/PatternMatch.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -438,6 +439,48 @@ class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
438439

439440
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
440441
reshape, resultTy, args[0], reassociationMap);
442+
443+
return success();
444+
}
445+
};
446+
447+
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
448+
public:
449+
using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
450+
451+
LogicalResult matchAndRewrite(tosa::TransposeOp op,
452+
PatternRewriter &rewriter) const final {
453+
DenseIntElementsAttr perms;
454+
if (!matchPattern(op.perms(), m_Constant(&perms))) {
455+
return failure();
456+
}
457+
458+
auto resultTy = op.getType().cast<ShapedType>();
459+
if (!resultTy.hasStaticShape())
460+
return failure();
461+
462+
SmallVector<AffineExpr, 2> inputExprs;
463+
inputExprs.resize(resultTy.getRank());
464+
for (auto permutation : llvm::enumerate(perms.getIntValues())) {
465+
inputExprs[permutation.value().getZExtValue()] =
466+
rewriter.getAffineDimExpr(permutation.index());
467+
}
468+
469+
auto initTensor = rewriter.create<linalg::InitTensorOp>(
470+
op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
471+
resultTy.getElementType());
472+
473+
SmallVector<AffineMap, 2> affineMaps = {
474+
AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
475+
rewriter.getContext()),
476+
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
477+
478+
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
479+
op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
480+
getNParallelLoopsAttrs(resultTy.getRank()),
481+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
482+
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
483+
});
441484
return success();
442485
}
443486
};
@@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
478521
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
479522
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
480523
IdentityNConverter<tosa::IdentityOp>,
481-
IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
524+
IdentityNConverter<tosa::IdentityNOp>,
525+
ReshapeOpConverter, TransposeConverter>(context);
482526
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,21 @@ func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32
317317
// CHECK: return %arg0, %arg1
318318
return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
319319
}
320+
321+
// -----
322+
323+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
324+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
325+
326+
// CHECK-LABEL: @test_transpose
327+
// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
328+
func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
329+
%0 = constant dense<[1, 2, 0]> : tensor<3xi32>
330+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
331+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
332+
// CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
333+
// CHECK: linalg.yield [[ARG1]]
334+
// CHECK: }
335+
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
336+
return
337+
}

0 commit comments

Comments
 (0)