@@ -1378,7 +1378,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1378
1378
return success ();
1379
1379
}
1380
1380
1381
- ArrayRef<int64_t > scale = op.getScale ();
1381
+ SmallVector<int64_t > scale;
1382
+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale)) {
1383
+ return failure ();
1384
+ }
1382
1385
1383
1386
// Collapse the unit width and height away.
1384
1387
SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
@@ -1440,105 +1443,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1440
1443
}
1441
1444
};
1442
1445
1443
- // TOSA resize with width or height of 1 may be broadcasted to a wider
1444
- // dimension. This is done by materializing a new tosa.resize without
1445
- // the broadcasting behavior, and an explicit broadcast afterwards.
1446
- class MaterializeResizeBroadcast : public OpRewritePattern <tosa::ResizeOp> {
1447
- public:
1448
- using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1449
-
1450
- LogicalResult matchAndRewrite (tosa::ResizeOp op,
1451
- PatternRewriter &rewriter) const final {
1452
- Location loc = op.getLoc ();
1453
- ImplicitLocOpBuilder builder (loc, rewriter);
1454
- auto input = op.getInput ();
1455
- auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
1456
- auto resultTy = dyn_cast<RankedTensorType>(op.getType ());
1457
-
1458
- if (!inputTy || !resultTy)
1459
- return rewriter.notifyMatchFailure (op,
1460
- " requires ranked input/output types" );
1461
-
1462
- auto batch = inputTy.getDimSize (0 );
1463
- auto channels = inputTy.getDimSize (3 );
1464
- auto inputH = inputTy.getDimSize (1 );
1465
- auto inputW = inputTy.getDimSize (2 );
1466
- auto outputH = resultTy.getDimSize (1 );
1467
- auto outputW = resultTy.getDimSize (2 );
1468
-
1469
- if ((inputH != 1 || outputH == 1 ) && (inputW != 1 || outputW == 1 ))
1470
- return rewriter.notifyMatchFailure (
1471
- op, " tosa.resize has no broadcasting behavior" );
1472
-
1473
- // For any dimension that is broadcastable we generate a width of 1
1474
- // on the output.
1475
- llvm::SmallVector<int64_t > resizeShape;
1476
- resizeShape.push_back (batch);
1477
- resizeShape.push_back (inputH == 1 ? 1 : outputH);
1478
- resizeShape.push_back (inputW == 1 ? 1 : outputW);
1479
- resizeShape.push_back (channels);
1480
-
1481
- auto resizeTy = resultTy.clone (resizeShape);
1482
- auto resize =
1483
- builder.create <tosa::ResizeOp>(resizeTy, input, op->getAttrs ());
1484
-
1485
- // Collapse an unit result dims.
1486
- SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
1487
- reassociationMap[0 ].push_back (builder.getAffineDimExpr (0 ));
1488
- reassociationMap.back ().push_back (builder.getAffineDimExpr (1 ));
1489
- if (inputH != 1 )
1490
- reassociationMap.push_back ({});
1491
- reassociationMap.back ().push_back (builder.getAffineDimExpr (2 ));
1492
- if (inputW != 1 )
1493
- reassociationMap.push_back ({});
1494
- reassociationMap.back ().push_back (builder.getAffineDimExpr (3 ));
1495
-
1496
- llvm::SmallVector<int64_t > collapseShape = {batch};
1497
- if (inputH != 1 )
1498
- collapseShape.push_back (outputH);
1499
- if (inputW != 1 )
1500
- collapseShape.push_back (outputW);
1501
- collapseShape.push_back (channels);
1502
-
1503
- auto collapseTy = resultTy.clone (collapseShape);
1504
- Value collapse = builder.create <tensor::CollapseShapeOp>(collapseTy, resize,
1505
- reassociationMap);
1506
-
1507
- // Broadcast the collapsed shape to the output result.
1508
- llvm::SmallVector<Value> outputDynSize;
1509
- if (inputTy.isDynamicDim (0 ))
1510
- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 0 ));
1511
- if (inputTy.isDynamicDim (3 ))
1512
- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 3 ));
1513
-
1514
- SmallVector<utils::IteratorType> iterators (resultTy.getRank (),
1515
- utils::IteratorType::parallel);
1516
- Value empty = builder.create <tensor::EmptyOp>(
1517
- resultTy.getShape (), resultTy.getElementType (), outputDynSize);
1518
-
1519
- SmallVector<AffineExpr, 4 > inputExprs{rewriter.getAffineDimExpr (0 )};
1520
- if (inputH != 1 )
1521
- inputExprs.push_back (rewriter.getAffineDimExpr (1 ));
1522
- if (inputW != 1 )
1523
- inputExprs.push_back (rewriter.getAffineDimExpr (2 ));
1524
- inputExprs.push_back (rewriter.getAffineDimExpr (3 ));
1525
-
1526
- auto inputMap = AffineMap::get (resultTy.getRank (), /* symbolCount=*/ 0 ,
1527
- inputExprs, rewriter.getContext ());
1528
-
1529
- auto outputMap = rewriter.getMultiDimIdentityMap (resultTy.getRank ());
1530
- rewriter.replaceOpWithNewOp <linalg::GenericOp>(
1531
- op, resultTy, ValueRange{collapse}, ValueRange{empty},
1532
- ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1533
- [=](OpBuilder &b, Location loc, ValueRange args) {
1534
- Value value = args[0 ];
1535
- b.create <linalg::YieldOp>(loc, value);
1536
- });
1537
-
1538
- return success ();
1539
- }
1540
- };
1541
-
1542
1446
class GenericResizeConverter : public OpRewritePattern <tosa::ResizeOp> {
1543
1447
public:
1544
1448
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1595,9 +1499,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1595
1499
Value inY = b.create <arith::IndexCastOp>(b.getI32Type (), y);
1596
1500
Value inX = b.create <arith::IndexCastOp>(b.getI32Type (), x);
1597
1501
1598
- ArrayRef<int64_t > offset = op.getOffset ();
1599
- ArrayRef<int64_t > border = op.getBorder ();
1600
- ArrayRef<int64_t > scale = op.getScale ();
1502
+ SmallVector<int64_t > scale, offset, border;
1503
+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale) ||
1504
+ !tosa::getConstShapeValue (op.getOffset ().getDefiningOp (), offset) ||
1505
+ !tosa::getConstShapeValue (op.getBorder ().getDefiningOp (), border)) {
1506
+ return rewriter.notifyMatchFailure (
1507
+ op, " tosa.resize scale/offset/border should have compile time "
1508
+ " constant values." );
1509
+ }
1601
1510
1602
1511
Value yScaleN, yScaleD, xScaleN, xScaleD;
1603
1512
yScaleN = b.create <arith::ConstantOp>(b.getI32IntegerAttr (scale[0 ]));
@@ -2607,8 +2516,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
2607
2516
/* benefit=*/ 100 );
2608
2517
patterns->add <ResizeUnaryConverter>(patterns->getContext (),
2609
2518
/* benefit=*/ 200 );
2610
- patterns->add <MaterializeResizeBroadcast>(patterns->getContext (),
2611
- /* benefit=*/ 300 );
2612
2519
2613
2520
patterns->add <
2614
2521
// clang-format off
0 commit comments