@@ -527,96 +527,99 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
527
527
}
528
528
529
529
bool checkErrorIfResize (Operation *op) {
530
- if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
531
- const Value input = resize.getInput ();
532
- const Value output = resize.getOutput ();
533
- const RankedTensorType inputType =
534
- llvm::dyn_cast<RankedTensorType>(input.getType ());
535
- const RankedTensorType outputType =
536
- llvm::dyn_cast<RankedTensorType>(output.getType ());
537
-
538
- if (!inputType || !outputType) {
539
- op->emitOpError (" expect ranked input/output tensor" );
540
- return false ;
541
- }
530
+ auto resize = dyn_cast<tosa::ResizeOp>(op);
531
+ if (!resize)
532
+ return true ;
542
533
543
- // Ensure the image size is supported by GPU APIs and that for integer
544
- // implementations, position * stride does not overflow int32_t.
545
- if (inputType.hasStaticShape () && outputType.hasStaticShape ()) {
546
- const SmallVector<int64_t , 4 > sizes = {
547
- outputType.getDimSize (1 ), outputType.getDimSize (2 ),
548
- inputType.getDimSize (1 ), inputType.getDimSize (2 )};
549
- const int64_t *maxDim = llvm::max_element (sizes);
550
- if (maxDim != sizes.end () && *maxDim >= 16384 ) {
551
- op->emitOpError (" expect input/output height/width dims to be < 16384, " )
552
- << " got [OH, OW, IH, IW] = " << sizes;
553
- return false ;
554
- }
555
- }
534
+ const Value input = resize.getInput ();
535
+ const Value output = resize.getOutput ();
536
+ const RankedTensorType inputType =
537
+ llvm::dyn_cast<RankedTensorType>(input.getType ());
538
+ const RankedTensorType outputType =
539
+ llvm::dyn_cast<RankedTensorType>(output.getType ());
556
540
557
- SmallVector<int64_t > scale;
558
- if (!tosa::getConstShapeValue (resize.getScale ().getDefiningOp (), scale)) {
541
+ if (!inputType || !outputType) {
542
+ op->emitOpError (" expect ranked input/output tensor" );
543
+ return false ;
544
+ }
545
+
546
+ // Ensure the image size is supported by GPU APIs and that for integer
547
+ // implementations, position * stride does not overflow int32_t.
548
+ if (inputType.hasStaticShape () && outputType.hasStaticShape ()) {
549
+ const SmallVector<int64_t , 4 > sizes = {
550
+ outputType.getDimSize (1 ), outputType.getDimSize (2 ),
551
+ inputType.getDimSize (1 ), inputType.getDimSize (2 )};
552
+ const int64_t *maxDim = llvm::max_element (sizes);
553
+ if (maxDim != sizes.end () && *maxDim >= 16384 ) {
554
+ op->emitOpError (" expect input/output height/width dims to be < 16384, " )
555
+ << " got [OH, OW, IH, IW] = " << sizes;
559
556
return false ;
560
557
}
558
+ }
561
559
562
- const int64_t scaleYN = scale[ 0 ] ;
563
- const int64_t scaleYD = scale[ 1 ];
564
- const int64_t scaleXN = scale[ 2 ] ;
565
- const int64_t scaleXD = scale[ 3 ];
560
+ SmallVector< int64_t > scale;
561
+ if (! tosa::getConstShapeValue (resize. getScale (). getDefiningOp (), scale)) {
562
+ return false ;
563
+ }
566
564
567
- // Ensure scale values don't overflow int32 accumulator
568
- if (scaleYN > (1 << 11 ) || scaleXN > (1 << 11 )) {
569
- op->emitOpError (" expect all scale numerator values to be <= (1 << 11), "
570
- " got scale_y_n=" )
571
- << scaleYN << " , scale_x_n=" << scaleXN;
572
- return false ;
573
- }
565
+ const int64_t scaleYN = scale[0 ];
566
+ const int64_t scaleYD = scale[1 ];
567
+ const int64_t scaleXN = scale[2 ];
568
+ const int64_t scaleXD = scale[3 ];
574
569
575
- if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
576
- op->emitOpError (" expect a downscale ratio larger than 1/16, got y=" )
577
- << scaleYN << " /" << scaleYD << " , x=" << scaleXN << " /" << scaleXD;
578
- return false ;
579
- }
570
+ // Ensure scale values don't overflow int32 accumulator
571
+ if (scaleYN > (1 << 11 ) || scaleXN > (1 << 11 )) {
572
+ op->emitOpError (" expect all scale numerator values to be <= (1 << 11), "
573
+ " got scale_y_n=" )
574
+ << scaleYN << " , scale_x_n=" << scaleXN;
575
+ return false ;
576
+ }
580
577
581
- SmallVector<int64_t > offset;
582
- SmallVector<int64_t > border;
583
- if (!tosa::getConstShapeValue (resize.getOffset ().getDefiningOp (), offset) ||
584
- !tosa::getConstShapeValue (resize.getBorder ().getDefiningOp (), border)) {
585
- return false ;
586
- }
578
+ if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
579
+ op->emitOpError (" expect a downscale ratio larger than 1/16, got y=" )
580
+ << scaleYN << " /" << scaleYD << " , x=" << scaleXN << " /" << scaleXD;
581
+ return false ;
582
+ }
587
583
588
- const int64_t offsetY = offset[0 ];
589
- const int64_t offsetX = offset[1 ];
590
- const int64_t borderY = border[0 ];
591
- const int64_t borderX = border[1 ];
592
-
593
- // Set a consistent lower limit of 1/16 downscale to simplify
594
- // implementations
595
- if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596
- op->emitOpError (
597
- " expect offsetY / scaleYNumerator to be in range [-1, 16), got " )
598
- << offsetY << " /" << scaleYN;
599
- return false ;
600
- }
601
- if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602
- op->emitOpError (
603
- " expect offsetX / scaleXNumerator to be in range [-1, 16), got " )
604
- << offsetX << " /" << scaleXN;
605
- return false ;
606
- }
607
- if (borderY < -16 * scaleYN || borderY >= scaleYN) {
608
- op->emitOpError (
609
- " expect borderY / scaleYNumerator to be in range [-16, 1), got " )
610
- << borderY << " /" << scaleYN;
611
- return false ;
612
- }
613
- if (borderX < -16 * scaleXN || borderX >= scaleXN) {
614
- op->emitOpError (
615
- " expect borderX / scaleXNumerator to be in range [-16, 1), got " )
616
- << borderX << " /" << scaleXN;
617
- return false ;
618
- }
584
+ SmallVector<int64_t > offset;
585
+ SmallVector<int64_t > border;
586
+ if (!tosa::getConstShapeValue (resize.getOffset ().getDefiningOp (), offset) ||
587
+ !tosa::getConstShapeValue (resize.getBorder ().getDefiningOp (), border)) {
588
+ return false ;
619
589
}
590
+
591
+ const int64_t offsetY = offset[0 ];
592
+ const int64_t offsetX = offset[1 ];
593
+ // Set a consistent lower limit of 1/16 downscale to simplify
594
+ // implementations
595
+ if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
596
+ op->emitOpError (
597
+ " expect offsetY / scaleYNumerator to be in range [-1, 16), got " )
598
+ << offsetY << " /" << scaleYN;
599
+ return false ;
600
+ }
601
+ if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
602
+ op->emitOpError (
603
+ " expect offsetX / scaleXNumerator to be in range [-1, 16), got " )
604
+ << offsetX << " /" << scaleXN;
605
+ return false ;
606
+ }
607
+
608
+ const int64_t borderY = border[0 ];
609
+ const int64_t borderX = border[1 ];
610
+ if (borderY < -16 * scaleYN || borderY >= scaleYN) {
611
+ op->emitOpError (
612
+ " expect borderY / scaleYNumerator to be in range [-16, 1), got " )
613
+ << borderY << " /" << scaleYN;
614
+ return false ;
615
+ }
616
+ if (borderX < -16 * scaleXN || borderX >= scaleXN) {
617
+ op->emitOpError (
618
+ " expect borderX / scaleXNumerator to be in range [-16, 1), got " )
619
+ << borderX << " /" << scaleXN;
620
+ return false ;
621
+ }
622
+
620
623
return true ;
621
624
}
622
625
0 commit comments