Skip to content

Commit dfb7701

Browse files
HsiangkaiTai78641lhutton1
committed
[mlir][tosa] Make TOSA RESIZE's scale, offset, border as Input
Move the `scale`, `offset`, and `border` parameters of the RESIZE operator in the MLIR TOSA dialect from attributes to inputs and update lit tests appropriately. Add the verifier of the `tosa::ResizeOp` operation. Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent dd369c7 commit dfb7701

File tree

13 files changed

+444
-213
lines changed

13 files changed

+444
-213
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,9 +1823,9 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18231823

18241824
let arguments = (ins
18251825
Tosa_Tensor4D:$input,
1826-
Tosa_IntArrayAttr4:$scale,
1827-
Tosa_IntArrayAttr2:$offset,
1828-
Tosa_IntArrayAttr2:$border,
1826+
Rank4TosaShape:$scale,
1827+
Rank2TosaShape:$offset,
1828+
Rank2TosaShape:$border,
18291829
Tosa_ResizeTypeAttr:$mode
18301830
);
18311831

@@ -1834,6 +1834,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18341834
);
18351835

18361836
let hasFolder = 1;
1837+
let hasVerifier = 1;
18371838
}
18381839

18391840
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
240240
bool getConstShapeValue(Operation *op,
241241
llvm::SmallVector<int64_t> &result_shape);
242242

243+
// returns a small vector of int64_t values that attr contains
244+
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
245+
const int rank);
243246
} // namespace tosa
244247
} // namespace mlir
245248

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13821382
return success();
13831383
}
13841384

1385-
ArrayRef<int64_t> scale = op.getScale();
1385+
SmallVector<int64_t> scale;
1386+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1387+
return failure();
1388+
}
13861389

13871390
// Collapse the unit width and height away.
13881391
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1444,105 +1447,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
14441447
}
14451448
};
14461449

1447-
// TOSA resize with width or height of 1 may be broadcasted to a wider
1448-
// dimension. This is done by materializing a new tosa.resize without
1449-
// the broadcasting behavior, and an explicit broadcast afterwards.
1450-
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1451-
public:
1452-
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1453-
1454-
LogicalResult matchAndRewrite(tosa::ResizeOp op,
1455-
PatternRewriter &rewriter) const final {
1456-
Location loc = op.getLoc();
1457-
ImplicitLocOpBuilder builder(loc, rewriter);
1458-
auto input = op.getInput();
1459-
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1460-
auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1461-
1462-
if (!inputTy || !resultTy)
1463-
return rewriter.notifyMatchFailure(op,
1464-
"requires ranked input/output types");
1465-
1466-
auto batch = inputTy.getDimSize(0);
1467-
auto channels = inputTy.getDimSize(3);
1468-
auto inputH = inputTy.getDimSize(1);
1469-
auto inputW = inputTy.getDimSize(2);
1470-
auto outputH = resultTy.getDimSize(1);
1471-
auto outputW = resultTy.getDimSize(2);
1472-
1473-
if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1474-
return rewriter.notifyMatchFailure(
1475-
op, "tosa.resize has no broadcasting behavior");
1476-
1477-
// For any dimension that is broadcastable we generate a width of 1
1478-
// on the output.
1479-
llvm::SmallVector<int64_t> resizeShape;
1480-
resizeShape.push_back(batch);
1481-
resizeShape.push_back(inputH == 1 ? 1 : outputH);
1482-
resizeShape.push_back(inputW == 1 ? 1 : outputW);
1483-
resizeShape.push_back(channels);
1484-
1485-
auto resizeTy = resultTy.clone(resizeShape);
1486-
auto resize =
1487-
builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1488-
1489-
// Collapse an unit result dims.
1490-
SmallVector<ReassociationExprs, 4> reassociationMap(2);
1491-
reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1492-
reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1493-
if (inputH != 1)
1494-
reassociationMap.push_back({});
1495-
reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1496-
if (inputW != 1)
1497-
reassociationMap.push_back({});
1498-
reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1499-
1500-
llvm::SmallVector<int64_t> collapseShape = {batch};
1501-
if (inputH != 1)
1502-
collapseShape.push_back(outputH);
1503-
if (inputW != 1)
1504-
collapseShape.push_back(outputW);
1505-
collapseShape.push_back(channels);
1506-
1507-
auto collapseTy = resultTy.clone(collapseShape);
1508-
Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1509-
reassociationMap);
1510-
1511-
// Broadcast the collapsed shape to the output result.
1512-
llvm::SmallVector<Value> outputDynSize;
1513-
if (inputTy.isDynamicDim(0))
1514-
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1515-
if (inputTy.isDynamicDim(3))
1516-
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1517-
1518-
SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1519-
utils::IteratorType::parallel);
1520-
Value empty = builder.create<tensor::EmptyOp>(
1521-
resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1522-
1523-
SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1524-
if (inputH != 1)
1525-
inputExprs.push_back(rewriter.getAffineDimExpr(1));
1526-
if (inputW != 1)
1527-
inputExprs.push_back(rewriter.getAffineDimExpr(2));
1528-
inputExprs.push_back(rewriter.getAffineDimExpr(3));
1529-
1530-
auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1531-
inputExprs, rewriter.getContext());
1532-
1533-
auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1534-
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1535-
op, resultTy, ValueRange{collapse}, ValueRange{empty},
1536-
ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1537-
[=](OpBuilder &b, Location loc, ValueRange args) {
1538-
Value value = args[0];
1539-
b.create<linalg::YieldOp>(loc, value);
1540-
});
1541-
1542-
return success();
1543-
}
1544-
};
1545-
15461450
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15471451
public:
15481452
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1599,9 +1503,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15991503
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
16001504
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
16011505

1602-
ArrayRef<int64_t> offset = op.getOffset();
1603-
ArrayRef<int64_t> border = op.getBorder();
1604-
ArrayRef<int64_t> scale = op.getScale();
1506+
SmallVector<int64_t> scale, offset, border;
1507+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1508+
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1509+
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1510+
return rewriter.notifyMatchFailure(
1511+
op, "tosa.resize scale/offset/border should have compile time "
1512+
"constant values.");
1513+
}
16051514

16061515
Value yScaleN, yScaleD, xScaleN, xScaleD;
16071516
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
@@ -2612,8 +2521,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26122521
/*benefit=*/100);
26132522
patterns->add<ResizeUnaryConverter>(patterns->getContext(),
26142523
/*benefit=*/200);
2615-
patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2616-
/*benefit=*/300);
26172524

26182525
patterns->add<
26192526
// clang-format off

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -975,9 +975,22 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
975975
// Fold away cases where a tosa.resize operation returns a copy
976976
// of the input image.
977977
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
978-
ArrayRef<int64_t> offset = getOffset();
979-
ArrayRef<int64_t> border = getBorder();
980-
ArrayRef<int64_t> scale = getScale();
978+
auto scaleAttr =
979+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
980+
auto offsetAttr =
981+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
982+
auto borderAttr =
983+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
984+
if (!scaleAttr || !offsetAttr || !borderAttr) {
985+
return {};
986+
}
987+
988+
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
989+
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
990+
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
991+
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
992+
return {};
993+
}
981994

982995
// Check unit scaling.
983996
if (scale[0] != scale[1] || scale[2] != scale[3]) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,9 +1656,14 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16561656
(inputWidth == ShapedType::kDynamic))
16571657
return failure();
16581658

1659-
llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1660-
llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1661-
llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1659+
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1660+
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1661+
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1662+
offsetInt) ||
1663+
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1664+
borderInt)) {
1665+
return failure();
1666+
}
16621667

16631668
// Compute the output shape based on attributes: scale, offset, and border.
16641669
outputShape[1] =
@@ -1675,6 +1680,90 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16751680
return success();
16761681
}
16771682

1683+
LogicalResult tosa::ResizeOp::verify() {
1684+
const Value input = getInput();
1685+
const Value output = getOutput();
1686+
const RankedTensorType inputType =
1687+
llvm::dyn_cast<RankedTensorType>(input.getType());
1688+
const RankedTensorType outputType =
1689+
llvm::dyn_cast<RankedTensorType>(output.getType());
1690+
1691+
if (!inputType)
1692+
return emitOpError("expect a ranked input tensor");
1693+
if (!outputType)
1694+
return emitOpError("expect a ranked output tensor");
1695+
1696+
const int64_t oh = outputType.getDimSize(1);
1697+
const int64_t ow = outputType.getDimSize(2);
1698+
const int64_t ih = inputType.getDimSize(1);
1699+
const int64_t iw = inputType.getDimSize(2);
1700+
1701+
SmallVector<int64_t> scaleValues;
1702+
SmallVector<int64_t> offsetValues;
1703+
SmallVector<int64_t> borderValues;
1704+
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1705+
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1706+
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1707+
// Skip following checks if shape is not constant
1708+
return success();
1709+
}
1710+
1711+
if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
1712+
return emitOpError("expect all scale values to be > 0, got ")
1713+
<< scaleValues;
1714+
1715+
const int64_t scaleYN = scaleValues[0];
1716+
const int64_t scaleYD = scaleValues[1];
1717+
const int64_t scaleXN = scaleValues[2];
1718+
const int64_t scaleXD = scaleValues[3];
1719+
1720+
const int64_t offsetY = offsetValues[0];
1721+
const int64_t offsetX = offsetValues[1];
1722+
1723+
const int64_t borderY = borderValues[0];
1724+
const int64_t borderX = borderValues[1];
1725+
1726+
auto idivCheck = [](const int64_t lhs,
1727+
const int64_t rhs) -> std::optional<int64_t> {
1728+
if (lhs % rhs != 0)
1729+
return std::nullopt;
1730+
return lhs / rhs;
1731+
};
1732+
1733+
if (ih != ShapedType::kDynamic) {
1734+
const std::optional<int64_t> calculatedOutHeightMinusOne =
1735+
idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1736+
if (!calculatedOutHeightMinusOne.has_value())
1737+
return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1738+
"border_y ")
1739+
<< "to be wholly divisible by scale_y_d, got ((" << ih
1740+
<< " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1741+
<< ") / " << scaleYD;
1742+
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1743+
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1744+
return emitOpError("calculated output height did not match expected: ")
1745+
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
1746+
}
1747+
1748+
if (iw != ShapedType::kDynamic) {
1749+
const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
1750+
const std::optional<int64_t> calculatedOutWidthMinusOne =
1751+
idivCheck(scaledInWidth, scaleXD);
1752+
if (!calculatedOutWidthMinusOne.has_value())
1753+
return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1754+
"border_x ")
1755+
<< "to be wholly divisible by scale_x_d, got ((" << iw
1756+
<< " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1757+
<< ") / " << scaleXD;
1758+
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1759+
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1760+
return emitOpError("calculated output width did not match expected: ")
1761+
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
1762+
}
1763+
1764+
return success();
1765+
}
1766+
16781767
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
16791768
MLIRContext *context, ::std::optional<Location> location,
16801769
ScatterOp::Adaptor adaptor,

0 commit comments

Comments
 (0)