Skip to content

Commit 2c9578b

Browse files
committed
Address comments
1 parent 1397556 commit 2c9578b

File tree

1 file changed

+83
-80
lines changed

1 file changed

+83
-80
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -527,96 +527,99 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
527527
}
528528

529529
bool checkErrorIfResize(Operation *op) {
530-
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
531-
const Value input = resize.getInput();
532-
const Value output = resize.getOutput();
533-
const RankedTensorType inputType =
534-
llvm::dyn_cast<RankedTensorType>(input.getType());
535-
const RankedTensorType outputType =
536-
llvm::dyn_cast<RankedTensorType>(output.getType());
537-
538-
if (!inputType || !outputType) {
539-
op->emitOpError("expect ranked input/output tensor");
540-
return false;
541-
}
530+
auto resize = dyn_cast<tosa::ResizeOp>(op);
531+
if (!resize)
532+
return true;
542533

543-
// Ensure the image size is supported by GPU APIs and that for integer
544-
// implementations, position * stride does not overflow int32_t.
545-
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
546-
const SmallVector<int64_t, 4> sizes = {
547-
outputType.getDimSize(1), outputType.getDimSize(2),
548-
inputType.getDimSize(1), inputType.getDimSize(2)};
549-
const int64_t *maxDim = llvm::max_element(sizes);
550-
if (maxDim != sizes.end() && *maxDim >= 16384) {
551-
op->emitOpError("expect input/output height/width dims to be < 16384, ")
552-
<< "got [OH, OW, IH, IW] = " << sizes;
553-
return false;
554-
}
555-
}
534+
const Value input = resize.getInput();
535+
const Value output = resize.getOutput();
536+
const RankedTensorType inputType =
537+
llvm::dyn_cast<RankedTensorType>(input.getType());
538+
const RankedTensorType outputType =
539+
llvm::dyn_cast<RankedTensorType>(output.getType());
556540

557-
SmallVector<int64_t> scale;
558-
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
541+
if (!inputType || !outputType) {
542+
op->emitOpError("expect ranked input/output tensor");
543+
return false;
544+
}
545+
546+
// Ensure the image size is supported by GPU APIs and that for integer
547+
// implementations, position * stride does not overflow int32_t.
548+
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
549+
const SmallVector<int64_t, 4> sizes = {
550+
outputType.getDimSize(1), outputType.getDimSize(2),
551+
inputType.getDimSize(1), inputType.getDimSize(2)};
552+
const int64_t *maxDim = llvm::max_element(sizes);
553+
if (maxDim != sizes.end() && *maxDim >= 16384) {
554+
op->emitOpError("expect input/output height/width dims to be < 16384, ")
555+
<< "got [OH, OW, IH, IW] = " << sizes;
559556
return false;
560557
}
558+
}
561559

562-
const int64_t scaleYN = scale[0];
563-
const int64_t scaleYD = scale[1];
564-
const int64_t scaleXN = scale[2];
565-
const int64_t scaleXD = scale[3];
560+
SmallVector<int64_t> scale;
561+
if (!tosa::getConstShapeValue(resize.getScale().getDefiningOp(), scale)) {
562+
return false;
563+
}
566564

567-
// Ensure scale values don't overflow int32 accumulator
568-
if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
569-
op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
570-
"got scale_y_n=")
571-
<< scaleYN << ", scale_x_n=" << scaleXN;
572-
return false;
573-
}
565+
const int64_t scaleYN = scale[0];
566+
const int64_t scaleYD = scale[1];
567+
const int64_t scaleXN = scale[2];
568+
const int64_t scaleXD = scale[3];
574569

575-
if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
576-
op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
577-
<< scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
578-
return false;
579-
}
570+
// Ensure scale values don't overflow int32 accumulator
571+
if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
572+
op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
573+
"got scale_y_n=")
574+
<< scaleYN << ", scale_x_n=" << scaleXN;
575+
return false;
576+
}
580577

581-
SmallVector<int64_t> offset;
582-
SmallVector<int64_t> border;
583-
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
584-
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
585-
return false;
586-
}
578+
if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
579+
op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
580+
<< scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
581+
return false;
582+
}
587583

588-
const int64_t offsetY = offset[0];
589-
const int64_t offsetX = offset[1];
590-
const int64_t borderY = border[0];
591-
const int64_t borderX = border[1];
592-
593-
// Set a consistent lower limit of 1/16 downscale to simplify
594-
// implementations
595-
if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596-
op->emitOpError(
597-
"expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
598-
<< offsetY << "/" << scaleYN;
599-
return false;
600-
}
601-
if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602-
op->emitOpError(
603-
"expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
604-
<< offsetX << "/" << scaleXN;
605-
return false;
606-
}
607-
if (borderY < -16 * scaleYN || borderY >= scaleYN) {
608-
op->emitOpError(
609-
"expect borderY / scaleYNumerator to be in range [-16, 1), got ")
610-
<< borderY << "/" << scaleYN;
611-
return false;
612-
}
613-
if (borderX < -16 * scaleXN || borderX >= scaleXN) {
614-
op->emitOpError(
615-
"expect borderX / scaleXNumerator to be in range [-16, 1), got ")
616-
<< borderX << "/" << scaleXN;
617-
return false;
618-
}
584+
SmallVector<int64_t> offset;
585+
SmallVector<int64_t> border;
586+
if (!tosa::getConstShapeValue(resize.getOffset().getDefiningOp(), offset) ||
587+
!tosa::getConstShapeValue(resize.getBorder().getDefiningOp(), border)) {
588+
return false;
619589
}
590+
591+
const int64_t offsetY = offset[0];
592+
const int64_t offsetX = offset[1];
593+
// Set a consistent lower limit of 1/16 downscale to simplify
594+
// implementations
595+
if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596+
op->emitOpError(
597+
"expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
598+
<< offsetY << "/" << scaleYN;
599+
return false;
600+
}
601+
if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602+
op->emitOpError(
603+
"expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
604+
<< offsetX << "/" << scaleXN;
605+
return false;
606+
}
607+
608+
const int64_t borderY = border[0];
609+
const int64_t borderX = border[1];
610+
if (borderY < -16 * scaleYN || borderY >= scaleYN) {
611+
op->emitOpError(
612+
"expect borderY / scaleYNumerator to be in range [-16, 1), got ")
613+
<< borderY << "/" << scaleYN;
614+
return false;
615+
}
616+
if (borderX < -16 * scaleXN || borderX >= scaleXN) {
617+
op->emitOpError(
618+
"expect borderX / scaleXNumerator to be in range [-16, 1), got ")
619+
<< borderX << "/" << scaleXN;
620+
return false;
621+
}
622+
620623
return true;
621624
}
622625

0 commit comments

Comments
 (0)