Skip to content

Commit 1f08f3a

Browse files
HsiangkaiTai78641lhutton1
committed
[mlir][tosa] Make TOSA RESIZE's scale, offset, border as Input
Move the `sclae`, `scale`, and `offset` parameters of the RESIZE operator in the MLIR TOSA dialect from attributes to inputs and update lit tests appropriately. Add the verifier of the `tosa::ResizeOp` operation. Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 5a4945f commit 1f08f3a

File tree

13 files changed

+430
-213
lines changed

13 files changed

+430
-213
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,9 +1853,9 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18531853

18541854
let arguments = (ins
18551855
Tosa_Tensor4D:$input,
1856-
Tosa_IntArrayAttr4:$scale,
1857-
Tosa_IntArrayAttr2:$offset,
1858-
Tosa_IntArrayAttr2:$border,
1856+
Rank4TosaShape:$scale,
1857+
Rank2TosaShape:$offset,
1858+
Rank2TosaShape:$border,
18591859
Tosa_ResizeTypeAttr:$mode
18601860
);
18611861

@@ -1864,6 +1864,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18641864
);
18651865

18661866
let hasFolder = 1;
1867+
let hasVerifier = 1;
18671868
}
18681869

18691870
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
237237
bool getConstShapeValue(Operation *op,
238238
llvm::SmallVector<int64_t> &result_shape);
239239

240+
// returns a small vector of int64_t values that attr contains
241+
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
242+
const int rank);
240243
} // namespace tosa
241244
} // namespace mlir
242245

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13781378
return success();
13791379
}
13801380

1381-
ArrayRef<int64_t> scale = op.getScale();
1381+
SmallVector<int64_t> scale;
1382+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1383+
return failure();
1384+
}
13821385

13831386
// Collapse the unit width and height away.
13841387
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1440,105 +1443,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
14401443
}
14411444
};
14421445

1443-
// TOSA resize with width or height of 1 may be broadcasted to a wider
1444-
// dimension. This is done by materializing a new tosa.resize without
1445-
// the broadcasting behavior, and an explicit broadcast afterwards.
1446-
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1447-
public:
1448-
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1449-
1450-
LogicalResult matchAndRewrite(tosa::ResizeOp op,
1451-
PatternRewriter &rewriter) const final {
1452-
Location loc = op.getLoc();
1453-
ImplicitLocOpBuilder builder(loc, rewriter);
1454-
auto input = op.getInput();
1455-
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1456-
auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1457-
1458-
if (!inputTy || !resultTy)
1459-
return rewriter.notifyMatchFailure(op,
1460-
"requires ranked input/output types");
1461-
1462-
auto batch = inputTy.getDimSize(0);
1463-
auto channels = inputTy.getDimSize(3);
1464-
auto inputH = inputTy.getDimSize(1);
1465-
auto inputW = inputTy.getDimSize(2);
1466-
auto outputH = resultTy.getDimSize(1);
1467-
auto outputW = resultTy.getDimSize(2);
1468-
1469-
if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1470-
return rewriter.notifyMatchFailure(
1471-
op, "tosa.resize has no broadcasting behavior");
1472-
1473-
// For any dimension that is broadcastable we generate a width of 1
1474-
// on the output.
1475-
llvm::SmallVector<int64_t> resizeShape;
1476-
resizeShape.push_back(batch);
1477-
resizeShape.push_back(inputH == 1 ? 1 : outputH);
1478-
resizeShape.push_back(inputW == 1 ? 1 : outputW);
1479-
resizeShape.push_back(channels);
1480-
1481-
auto resizeTy = resultTy.clone(resizeShape);
1482-
auto resize =
1483-
builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1484-
1485-
// Collapse an unit result dims.
1486-
SmallVector<ReassociationExprs, 4> reassociationMap(2);
1487-
reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1488-
reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1489-
if (inputH != 1)
1490-
reassociationMap.push_back({});
1491-
reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1492-
if (inputW != 1)
1493-
reassociationMap.push_back({});
1494-
reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1495-
1496-
llvm::SmallVector<int64_t> collapseShape = {batch};
1497-
if (inputH != 1)
1498-
collapseShape.push_back(outputH);
1499-
if (inputW != 1)
1500-
collapseShape.push_back(outputW);
1501-
collapseShape.push_back(channels);
1502-
1503-
auto collapseTy = resultTy.clone(collapseShape);
1504-
Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1505-
reassociationMap);
1506-
1507-
// Broadcast the collapsed shape to the output result.
1508-
llvm::SmallVector<Value> outputDynSize;
1509-
if (inputTy.isDynamicDim(0))
1510-
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1511-
if (inputTy.isDynamicDim(3))
1512-
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1513-
1514-
SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1515-
utils::IteratorType::parallel);
1516-
Value empty = builder.create<tensor::EmptyOp>(
1517-
resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1518-
1519-
SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1520-
if (inputH != 1)
1521-
inputExprs.push_back(rewriter.getAffineDimExpr(1));
1522-
if (inputW != 1)
1523-
inputExprs.push_back(rewriter.getAffineDimExpr(2));
1524-
inputExprs.push_back(rewriter.getAffineDimExpr(3));
1525-
1526-
auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1527-
inputExprs, rewriter.getContext());
1528-
1529-
auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1530-
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1531-
op, resultTy, ValueRange{collapse}, ValueRange{empty},
1532-
ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1533-
[=](OpBuilder &b, Location loc, ValueRange args) {
1534-
Value value = args[0];
1535-
b.create<linalg::YieldOp>(loc, value);
1536-
});
1537-
1538-
return success();
1539-
}
1540-
};
1541-
15421446
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15431447
public:
15441448
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1595,9 +1499,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15951499
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
15961500
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
15971501

1598-
ArrayRef<int64_t> offset = op.getOffset();
1599-
ArrayRef<int64_t> border = op.getBorder();
1600-
ArrayRef<int64_t> scale = op.getScale();
1502+
SmallVector<int64_t> scale, offset, border;
1503+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1504+
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1505+
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1506+
return rewriter.notifyMatchFailure(
1507+
op, "tosa.resize scale/offset/border should have compile time "
1508+
"constant values.");
1509+
}
16011510

16021511
Value yScaleN, yScaleD, xScaleN, xScaleD;
16031512
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
@@ -2607,8 +2516,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26072516
/*benefit=*/100);
26082517
patterns->add<ResizeUnaryConverter>(patterns->getContext(),
26092518
/*benefit=*/200);
2610-
patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2611-
/*benefit=*/300);
26122519

26132520
patterns->add<
26142521
// clang-format off

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,9 +955,22 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
955955
// Fold away cases where a tosa.resize operation returns a copy
956956
// of the input image.
957957
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
958-
ArrayRef<int64_t> offset = getOffset();
959-
ArrayRef<int64_t> border = getBorder();
960-
ArrayRef<int64_t> scale = getScale();
958+
auto scaleAttr =
959+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
960+
auto offsetAttr =
961+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
962+
auto borderAttr =
963+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
964+
if (!scaleAttr || !offsetAttr || !borderAttr) {
965+
return {};
966+
}
967+
968+
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
969+
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
970+
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
971+
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
972+
return {};
973+
}
961974

962975
// Check unit scaling.
963976
if (scale[0] != scale[1] || scale[2] != scale[3]) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,9 +1451,14 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
14511451
(inputWidth == ShapedType::kDynamic))
14521452
return failure();
14531453

1454-
llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1455-
llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1456-
llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1454+
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1455+
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1456+
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1457+
offsetInt) ||
1458+
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1459+
borderInt)) {
1460+
return failure();
1461+
}
14571462

14581463
// Compute the output shape based on attributes: scale, offset, and border.
14591464
outputShape[1] =
@@ -1470,6 +1475,81 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
14701475
return success();
14711476
}
14721477

1478+
LogicalResult tosa::ResizeOp::verify() {
1479+
const Value input = getInput();
1480+
const Value output = getOutput();
1481+
const RankedTensorType inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
1482+
const RankedTensorType outputType = llvm::dyn_cast<RankedTensorType>(output.getType());
1483+
1484+
if (!inputType)
1485+
return emitOpError("expect a ranked input tensor");
1486+
if (!outputType)
1487+
return emitOpError("expect a ranked output tensor");
1488+
1489+
const int64_t oh = outputType.getDimSize(1);
1490+
const int64_t ow = outputType.getDimSize(2);
1491+
const int64_t ih = inputType.getDimSize(1);
1492+
const int64_t iw = inputType.getDimSize(2);
1493+
1494+
SmallVector<int64_t> scaleValues;
1495+
SmallVector<int64_t> offsetValues;
1496+
SmallVector<int64_t> borderValues;
1497+
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1498+
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1499+
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1500+
// Skip following checks if shape is not constant
1501+
return success();
1502+
}
1503+
1504+
if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
1505+
return emitOpError("expect all scale values to be > 0, got ") << scaleValues;
1506+
1507+
const int64_t scaleYN = scaleValues[0];
1508+
const int64_t scaleYD = scaleValues[1];
1509+
const int64_t scaleXN = scaleValues[2];
1510+
const int64_t scaleXD = scaleValues[3];
1511+
1512+
const int64_t offsetY = offsetValues[0];
1513+
const int64_t offsetX = offsetValues[1];
1514+
1515+
const int64_t borderY = borderValues[0];
1516+
const int64_t borderX = borderValues[1];
1517+
1518+
auto idivCheck = [](const int64_t lhs, const int64_t rhs) -> std::optional<int64_t> {
1519+
if (lhs % rhs != 0)
1520+
return std::nullopt;
1521+
return lhs / rhs;
1522+
};
1523+
1524+
if (ih != ShapedType::kDynamic) {
1525+
const std::optional<int64_t> calculatedOutHeightMinusOne = idivCheck(
1526+
(ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1527+
if (!calculatedOutHeightMinusOne.has_value())
1528+
return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + border_y ")
1529+
<< "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * " << scaleYN
1530+
<< " - " << offsetY << " + " << borderY << ") / " << scaleYD;
1531+
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1532+
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1533+
return emitOpError("calculated output height did not match expected: ")
1534+
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
1535+
}
1536+
1537+
if (iw != ShapedType::kDynamic) {
1538+
const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
1539+
const std::optional<int64_t> calculatedOutWidthMinusOne = idivCheck(scaledInWidth, scaleXD);
1540+
if (!calculatedOutWidthMinusOne.has_value())
1541+
return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + border_x ")
1542+
<< "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * " << scaleXN
1543+
<< " - " << offsetX << " + " << borderX << ") / " << scaleXD;
1544+
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1545+
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1546+
return emitOpError("calculated output width did not match expected: ")
1547+
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
1548+
}
1549+
1550+
return success();
1551+
}
1552+
14731553
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
14741554
MLIRContext *context, ::std::optional<Location> location,
14751555
ScatterOp::Adaptor adaptor,

0 commit comments

Comments
 (0)