|
15 | 15 | #include "mlir/Dialect/Math/IR/Math.h"
|
16 | 16 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
17 | 17 | #include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
| 18 | +#include "mlir/IR/Matchers.h" |
18 | 19 | #include "mlir/IR/PatternMatch.h"
|
19 | 20 | #include "mlir/Transforms/DialectConversion.h"
|
20 | 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
@@ -438,6 +439,48 @@ class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
438 | 439 |
|
439 | 440 | rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
440 | 441 | reshape, resultTy, args[0], reassociationMap);
|
| 442 | + |
| 443 | + return success(); |
| 444 | + } |
| 445 | +}; |
| 446 | + |
| 447 | +class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> { |
| 448 | +public: |
| 449 | + using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern; |
| 450 | + |
| 451 | + LogicalResult matchAndRewrite(tosa::TransposeOp op, |
| 452 | + PatternRewriter &rewriter) const final { |
| 453 | + DenseIntElementsAttr perms; |
| 454 | + if (!matchPattern(op.perms(), m_Constant(&perms))) { |
| 455 | + return failure(); |
| 456 | + } |
| 457 | + |
| 458 | + auto resultTy = op.getType().cast<ShapedType>(); |
| 459 | + if (!resultTy.hasStaticShape()) |
| 460 | + return failure(); |
| 461 | + |
| 462 | + SmallVector<AffineExpr, 2> inputExprs; |
| 463 | + inputExprs.resize(resultTy.getRank()); |
| 464 | + for (auto permutation : llvm::enumerate(perms.getIntValues())) { |
| 465 | + inputExprs[permutation.value().getZExtValue()] = |
| 466 | + rewriter.getAffineDimExpr(permutation.index()); |
| 467 | + } |
| 468 | + |
| 469 | + auto initTensor = rewriter.create<linalg::InitTensorOp>( |
| 470 | + op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(), |
| 471 | + resultTy.getElementType()); |
| 472 | + |
| 473 | + SmallVector<AffineMap, 2> affineMaps = { |
| 474 | + AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, |
| 475 | + rewriter.getContext()), |
| 476 | + rewriter.getMultiDimIdentityMap(resultTy.getRank())}; |
| 477 | + |
| 478 | + rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
| 479 | + op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps, |
| 480 | + getNParallelLoopsAttrs(resultTy.getRank()), |
| 481 | + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
| 482 | + nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin()); |
| 483 | + }); |
441 | 484 | return success();
|
442 | 485 | }
|
443 | 486 | };
|
@@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
478 | 521 | PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
|
479 | 522 | PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
|
480 | 523 | IdentityNConverter<tosa::IdentityOp>,
|
481 |
| - IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context); |
| 524 | + IdentityNConverter<tosa::IdentityNOp>, |
| 525 | + ReshapeOpConverter, TransposeConverter>(context); |
482 | 526 | }
|
0 commit comments