@@ -1382,7 +1382,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1382
1382
return success ();
1383
1383
}
1384
1384
1385
- ArrayRef<int64_t > scale = op.getScale ();
1385
+ SmallVector<int64_t > scale;
1386
+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale)) {
1387
+ return failure ();
1388
+ }
1386
1389
1387
1390
// Collapse the unit width and height away.
1388
1391
SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
@@ -1444,105 +1447,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1444
1447
}
1445
1448
};
1446
1449
1447
- // TOSA resize with width or height of 1 may be broadcasted to a wider
1448
- // dimension. This is done by materializing a new tosa.resize without
1449
- // the broadcasting behavior, and an explicit broadcast afterwards.
1450
- class MaterializeResizeBroadcast : public OpRewritePattern <tosa::ResizeOp> {
1451
- public:
1452
- using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1453
-
1454
- LogicalResult matchAndRewrite (tosa::ResizeOp op,
1455
- PatternRewriter &rewriter) const final {
1456
- Location loc = op.getLoc ();
1457
- ImplicitLocOpBuilder builder (loc, rewriter);
1458
- auto input = op.getInput ();
1459
- auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
1460
- auto resultTy = dyn_cast<RankedTensorType>(op.getType ());
1461
-
1462
- if (!inputTy || !resultTy)
1463
- return rewriter.notifyMatchFailure (op,
1464
- " requires ranked input/output types" );
1465
-
1466
- auto batch = inputTy.getDimSize (0 );
1467
- auto channels = inputTy.getDimSize (3 );
1468
- auto inputH = inputTy.getDimSize (1 );
1469
- auto inputW = inputTy.getDimSize (2 );
1470
- auto outputH = resultTy.getDimSize (1 );
1471
- auto outputW = resultTy.getDimSize (2 );
1472
-
1473
- if ((inputH != 1 || outputH == 1 ) && (inputW != 1 || outputW == 1 ))
1474
- return rewriter.notifyMatchFailure (
1475
- op, " tosa.resize has no broadcasting behavior" );
1476
-
1477
- // For any dimension that is broadcastable we generate a width of 1
1478
- // on the output.
1479
- llvm::SmallVector<int64_t > resizeShape;
1480
- resizeShape.push_back (batch);
1481
- resizeShape.push_back (inputH == 1 ? 1 : outputH);
1482
- resizeShape.push_back (inputW == 1 ? 1 : outputW);
1483
- resizeShape.push_back (channels);
1484
-
1485
- auto resizeTy = resultTy.clone (resizeShape);
1486
- auto resize =
1487
- builder.create <tosa::ResizeOp>(resizeTy, input, op->getAttrs ());
1488
-
1489
- // Collapse an unit result dims.
1490
- SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
1491
- reassociationMap[0 ].push_back (builder.getAffineDimExpr (0 ));
1492
- reassociationMap.back ().push_back (builder.getAffineDimExpr (1 ));
1493
- if (inputH != 1 )
1494
- reassociationMap.push_back ({});
1495
- reassociationMap.back ().push_back (builder.getAffineDimExpr (2 ));
1496
- if (inputW != 1 )
1497
- reassociationMap.push_back ({});
1498
- reassociationMap.back ().push_back (builder.getAffineDimExpr (3 ));
1499
-
1500
- llvm::SmallVector<int64_t > collapseShape = {batch};
1501
- if (inputH != 1 )
1502
- collapseShape.push_back (outputH);
1503
- if (inputW != 1 )
1504
- collapseShape.push_back (outputW);
1505
- collapseShape.push_back (channels);
1506
-
1507
- auto collapseTy = resultTy.clone (collapseShape);
1508
- Value collapse = builder.create <tensor::CollapseShapeOp>(collapseTy, resize,
1509
- reassociationMap);
1510
-
1511
- // Broadcast the collapsed shape to the output result.
1512
- llvm::SmallVector<Value> outputDynSize;
1513
- if (inputTy.isDynamicDim (0 ))
1514
- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 0 ));
1515
- if (inputTy.isDynamicDim (3 ))
1516
- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 3 ));
1517
-
1518
- SmallVector<utils::IteratorType> iterators (resultTy.getRank (),
1519
- utils::IteratorType::parallel);
1520
- Value empty = builder.create <tensor::EmptyOp>(
1521
- resultTy.getShape (), resultTy.getElementType (), outputDynSize);
1522
-
1523
- SmallVector<AffineExpr, 4 > inputExprs{rewriter.getAffineDimExpr (0 )};
1524
- if (inputH != 1 )
1525
- inputExprs.push_back (rewriter.getAffineDimExpr (1 ));
1526
- if (inputW != 1 )
1527
- inputExprs.push_back (rewriter.getAffineDimExpr (2 ));
1528
- inputExprs.push_back (rewriter.getAffineDimExpr (3 ));
1529
-
1530
- auto inputMap = AffineMap::get (resultTy.getRank (), /* symbolCount=*/ 0 ,
1531
- inputExprs, rewriter.getContext ());
1532
-
1533
- auto outputMap = rewriter.getMultiDimIdentityMap (resultTy.getRank ());
1534
- rewriter.replaceOpWithNewOp <linalg::GenericOp>(
1535
- op, resultTy, ValueRange{collapse}, ValueRange{empty},
1536
- ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1537
- [=](OpBuilder &b, Location loc, ValueRange args) {
1538
- Value value = args[0 ];
1539
- b.create <linalg::YieldOp>(loc, value);
1540
- });
1541
-
1542
- return success ();
1543
- }
1544
- };
1545
-
1546
1450
class GenericResizeConverter : public OpRewritePattern <tosa::ResizeOp> {
1547
1451
public:
1548
1452
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1599,9 +1503,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1599
1503
Value inY = b.create <arith::IndexCastOp>(b.getI32Type (), y);
1600
1504
Value inX = b.create <arith::IndexCastOp>(b.getI32Type (), x);
1601
1505
1602
- ArrayRef<int64_t > offset = op.getOffset ();
1603
- ArrayRef<int64_t > border = op.getBorder ();
1604
- ArrayRef<int64_t > scale = op.getScale ();
1506
+ SmallVector<int64_t > scale, offset, border;
1507
+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale) ||
1508
+ !tosa::getConstShapeValue (op.getOffset ().getDefiningOp (), offset) ||
1509
+ !tosa::getConstShapeValue (op.getBorder ().getDefiningOp (), border)) {
1510
+ return rewriter.notifyMatchFailure (
1511
+ op, " tosa.resize scale/offset/border should have compile time "
1512
+ " constant values." );
1513
+ }
1605
1514
1606
1515
Value yScaleN, yScaleD, xScaleN, xScaleD;
1607
1516
yScaleN = b.create <arith::ConstantOp>(b.getI32IntegerAttr (scale[0 ]));
@@ -2612,8 +2521,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
2612
2521
/* benefit=*/ 100 );
2613
2522
patterns->add <ResizeUnaryConverter>(patterns->getContext (),
2614
2523
/* benefit=*/ 200 );
2615
- patterns->add <MaterializeResizeBroadcast>(patterns->getContext (),
2616
- /* benefit=*/ 300 );
2617
2524
2618
2525
patterns->add <
2619
2526
// clang-format off
0 commit comments