Skip to content

Commit 1397556

Browse files
HsiangkaiTai78641lhutton1
committed
[mlir][tosa] Make TOSA RESIZE's scale, offset, border as Input
Move the `scale`, `offset`, and `border` parameters of the RESIZE operator in the MLIR TOSA dialect from attributes to inputs and update lit tests appropriately. Add the verifier of the `tosa::ResizeOp` operation. Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 0e779ad commit 1397556

File tree

13 files changed

+446
-114
lines changed

13 files changed

+446
-114
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,9 +1822,9 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18221822

18231823
let arguments = (ins
18241824
Tosa_Tensor4D:$input,
1825-
Tosa_IntArrayAttr4:$scale,
1826-
Tosa_IntArrayAttr2:$offset,
1827-
Tosa_IntArrayAttr2:$border,
1825+
Rank4TosaShape:$scale,
1826+
Rank2TosaShape:$offset,
1827+
Rank2TosaShape:$border,
18281828
Tosa_ResizeTypeAttr:$mode
18291829
);
18301830

@@ -1833,6 +1833,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18331833
);
18341834

18351835
let hasFolder = 1;
1836+
let hasVerifier = 1;
18361837
}
18371838

18381839
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
240240
bool getConstShapeValue(Operation *op,
241241
llvm::SmallVector<int64_t> &result_shape);
242242

243+
// returns a small vector of int64_t values that attr contains
244+
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
245+
const int rank);
243246
} // namespace tosa
244247
} // namespace mlir
245248

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13871387
return success();
13881388
}
13891389

1390-
ArrayRef<int64_t> scale = op.getScale();
1390+
SmallVector<int64_t> scale;
1391+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1392+
return failure();
1393+
}
13911394

13921395
// Collapse the unit width and height away.
13931396
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1488,8 +1491,9 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
14881491
resizeShape.push_back(channels);
14891492

14901493
auto resizeTy = resultTy.clone(resizeShape);
1491-
auto resize =
1492-
builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1494+
auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
1495+
op.getOffset(), op.getBorder(),
1496+
op.getMode());
14931497

14941498
// Collapse an unit result dims.
14951499
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1604,9 +1608,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
16041608
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
16051609
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
16061610

1607-
ArrayRef<int64_t> offset = op.getOffset();
1608-
ArrayRef<int64_t> border = op.getBorder();
1609-
ArrayRef<int64_t> scale = op.getScale();
1611+
SmallVector<int64_t> scale, offset, border;
1612+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1613+
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1614+
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1615+
return rewriter.notifyMatchFailure(
1616+
op, "tosa.resize scale/offset/border should have compile time "
1617+
"constant values.");
1618+
}
16101619

16111620
Value yScaleN, yScaleD, xScaleN, xScaleD;
16121621
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,22 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
10341034
// Fold away cases where a tosa.resize operation returns a copy
10351035
// of the input image.
10361036
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1037-
ArrayRef<int64_t> offset = getOffset();
1038-
ArrayRef<int64_t> border = getBorder();
1039-
ArrayRef<int64_t> scale = getScale();
1037+
auto scaleAttr =
1038+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1039+
auto offsetAttr =
1040+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1041+
auto borderAttr =
1042+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1043+
if (!scaleAttr || !offsetAttr || !borderAttr) {
1044+
return {};
1045+
}
1046+
1047+
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1048+
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1049+
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1050+
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1051+
return {};
1052+
}
10401053

10411054
// Check unit scaling.
10421055
if (scale[0] != scale[1] || scale[2] != scale[3]) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,9 +1685,14 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16851685
(inputWidth == ShapedType::kDynamic))
16861686
return failure();
16871687

1688-
llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1689-
llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1690-
llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1688+
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1689+
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1690+
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1691+
offsetInt) ||
1692+
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1693+
borderInt)) {
1694+
return failure();
1695+
}
16911696

16921697
// Compute the output shape based on attributes: scale, offset, and border.
16931698
outputShape[1] =
@@ -1704,6 +1709,90 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
17041709
return success();
17051710
}
17061711

1712+
LogicalResult tosa::ResizeOp::verify() {
1713+
const Value input = getInput();
1714+
const Value output = getOutput();
1715+
const RankedTensorType inputType =
1716+
llvm::dyn_cast<RankedTensorType>(input.getType());
1717+
const RankedTensorType outputType =
1718+
llvm::dyn_cast<RankedTensorType>(output.getType());
1719+
1720+
if (!inputType)
1721+
return emitOpError("expect a ranked input tensor");
1722+
if (!outputType)
1723+
return emitOpError("expect a ranked output tensor");
1724+
1725+
const int64_t oh = outputType.getDimSize(1);
1726+
const int64_t ow = outputType.getDimSize(2);
1727+
const int64_t ih = inputType.getDimSize(1);
1728+
const int64_t iw = inputType.getDimSize(2);
1729+
1730+
SmallVector<int64_t> scaleValues;
1731+
SmallVector<int64_t> offsetValues;
1732+
SmallVector<int64_t> borderValues;
1733+
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1734+
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1735+
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1736+
// Skip following checks if shape is not constant
1737+
return success();
1738+
}
1739+
1740+
if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
1741+
return emitOpError("expect all scale values to be > 0, got ")
1742+
<< scaleValues;
1743+
1744+
const int64_t scaleYN = scaleValues[0];
1745+
const int64_t scaleYD = scaleValues[1];
1746+
const int64_t scaleXN = scaleValues[2];
1747+
const int64_t scaleXD = scaleValues[3];
1748+
1749+
const int64_t offsetY = offsetValues[0];
1750+
const int64_t offsetX = offsetValues[1];
1751+
1752+
const int64_t borderY = borderValues[0];
1753+
const int64_t borderX = borderValues[1];
1754+
1755+
auto idivCheck = [](const int64_t lhs,
1756+
const int64_t rhs) -> std::optional<int64_t> {
1757+
if (lhs % rhs != 0)
1758+
return std::nullopt;
1759+
return lhs / rhs;
1760+
};
1761+
1762+
if (ih != ShapedType::kDynamic) {
1763+
const std::optional<int64_t> calculatedOutHeightMinusOne =
1764+
idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1765+
if (!calculatedOutHeightMinusOne.has_value())
1766+
return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1767+
"border_y ")
1768+
<< "to be wholly divisible by scale_y_d, got ((" << ih
1769+
<< " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1770+
<< ") / " << scaleYD;
1771+
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1772+
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1773+
return emitOpError("calculated output height did not match expected: ")
1774+
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
1775+
}
1776+
1777+
if (iw != ShapedType::kDynamic) {
1778+
const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
1779+
const std::optional<int64_t> calculatedOutWidthMinusOne =
1780+
idivCheck(scaledInWidth, scaleXD);
1781+
if (!calculatedOutWidthMinusOne.has_value())
1782+
return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1783+
"border_x ")
1784+
<< "to be wholly divisible by scale_x_d, got ((" << iw
1785+
<< " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1786+
<< ") / " << scaleXD;
1787+
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1788+
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1789+
return emitOpError("calculated output width did not match expected: ")
1790+
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
1791+
}
1792+
1793+
return success();
1794+
}
1795+
17071796
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17081797
MLIRContext *context, ::std::optional<Location> location,
17091798
ScatterOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
21+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2122
#include "mlir/IR/Builders.h"
2223
#include "mlir/IR/BuiltinOps.h"
2324
#include "mlir/IR/Matchers.h"
@@ -119,6 +120,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
119120
// check variable read/write data types against variable declarations
120121
LogicalResult applyVariableCheck(Operation *op);
121122

123+
// check error if conditions
124+
LogicalResult applyErrorIfCheck(Operation *op);
125+
122126
private:
123127
void populateConstantOperandChecks() {
124128
constCheckers.emplace_back(checkConstantOperandPad);
@@ -383,11 +387,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
383387
// Resize op: level check max scales
384388
bool levelCheckResize(Operation *op) {
385389
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386-
auto scale = resize.getScale();
387-
int16_t scaleYN = scale[0];
388-
int16_t scaleYD = scale[1];
389-
int16_t scaleXN = scale[2];
390-
int16_t scaleXD = scale[3];
390+
SmallVector<int64_t> scale;
391+
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
392+
return false;
393+
}
394+
const int64_t scaleYN = scale[0];
395+
const int64_t scaleYD = scale[1];
396+
const int64_t scaleXN = scale[2];
397+
const int64_t scaleXD = scale[3];
391398
if (!levelCheckScale(op, scaleYN / scaleYD,
392399
"scale_y_n/scale_y_d <= MAX_SCALE") ||
393400
!levelCheckScale(op, scaleXN / scaleXD,
@@ -519,6 +526,106 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
519526
return success();
520527
}
521528

529+
bool checkErrorIfResize(Operation *op) {
530+
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
531+
const Value input = resize.getInput();
532+
const Value output = resize.getOutput();
533+
const RankedTensorType inputType =
534+
llvm::dyn_cast<RankedTensorType>(input.getType());
535+
const RankedTensorType outputType =
536+
llvm::dyn_cast<RankedTensorType>(output.getType());
537+
538+
if (!inputType || !outputType) {
539+
op->emitOpError("expect ranked input/output tensor");
540+
return false;
541+
}
542+
543+
// Ensure the image size is supported by GPU APIs and that for integer
544+
// implementations, position * stride does not overflow int32_t.
545+
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
546+
const SmallVector<int64_t, 4> sizes = {
547+
outputType.getDimSize(1), outputType.getDimSize(2),
548+
inputType.getDimSize(1), inputType.getDimSize(2)};
549+
const int64_t *maxDim = llvm::max_element(sizes);
550+
if (maxDim != sizes.end() && *maxDim >= 16384) {
551+
op->emitOpError("expect input/output height/width dims to be < 16384, ")
552+
<< "got [OH, OW, IH, IW] = " << sizes;
553+
return false;
554+
}
555+
}
556+
557+
SmallVector<int64_t> scale;
558+
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
559+
return false;
560+
}
561+
562+
const int64_t scaleYN = scale[0];
563+
const int64_t scaleYD = scale[1];
564+
const int64_t scaleXN = scale[2];
565+
const int64_t scaleXD = scale[3];
566+
567+
// Ensure scale values don't overflow int32 accumulator
568+
if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
569+
op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
570+
"got scale_y_n=")
571+
<< scaleYN << ", scale_x_n=" << scaleXN;
572+
return false;
573+
}
574+
575+
if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
576+
op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
577+
<< scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
578+
return false;
579+
}
580+
581+
SmallVector<int64_t> offset;
582+
SmallVector<int64_t> border;
583+
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
584+
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
585+
return false;
586+
}
587+
588+
const int64_t offsetY = offset[0];
589+
const int64_t offsetX = offset[1];
590+
const int64_t borderY = border[0];
591+
const int64_t borderX = border[1];
592+
593+
// Set a consistent lower limit of 1/16 downscale to simplify
594+
// implementations
595+
if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596+
op->emitOpError(
597+
"expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
598+
<< offsetY << "/" << scaleYN;
599+
return false;
600+
}
601+
if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602+
op->emitOpError(
603+
"expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
604+
<< offsetX << "/" << scaleXN;
605+
return false;
606+
}
607+
if (borderY < -16 * scaleYN || borderY >= scaleYN) {
608+
op->emitOpError(
609+
"expect borderY / scaleYNumerator to be in range [-16, 1), got ")
610+
<< borderY << "/" << scaleYN;
611+
return false;
612+
}
613+
if (borderX < -16 * scaleXN || borderX >= scaleXN) {
614+
op->emitOpError(
615+
"expect borderX / scaleXNumerator to be in range [-16, 1), got ")
616+
<< borderX << "/" << scaleXN;
617+
return false;
618+
}
619+
}
620+
return true;
621+
}
622+
623+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
624+
if (!checkErrorIfResize(op))
625+
return failure();
626+
return success();
627+
}
628+
522629
bool TosaValidation::isValidElementType(Type type) {
523630
if (isa<FloatType>(type)) {
524631
if (!isEnabledProfile(TosaProfileEnum::MainInference))
@@ -582,6 +689,10 @@ void TosaValidation::runOnOperation() {
582689
// do variable type checks
583690
if (failed(applyVariableCheck(op)))
584691
signalPassFailure();
692+
693+
// do error if checks
694+
if (StrictOperationSpecAlignment && failed(applyErrorIfCheck(op)))
695+
signalPassFailure();
585696
});
586697
}
587698
} // namespace

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,21 @@ bool mlir::tosa::getConstShapeValue(Operation *op,
198198
// for undefined op, return false.
199199
return false;
200200
}
201+
202+
// returns a small vector of int64_t values that attr contains
203+
SmallVector<int64_t>
204+
mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
205+
if (attr.isSplat()) {
206+
int64_t v = attr.getSplatValue<APInt>().getSExtValue();
207+
return SmallVector<int64_t>(rank, v);
208+
}
209+
210+
if (auto int_array_attr = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
211+
SmallVector<int64_t> vec;
212+
for (APInt val : int_array_attr.getValues<APInt>()) {
213+
vec.push_back(val.getSExtValue());
214+
}
215+
return vec;
216+
}
217+
return {};
218+
}

0 commit comments

Comments
 (0)