@@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
428
428
return success ();
429
429
}
430
430
431
+ // ===----------------------------------------------------------------------===//
432
+ // ERROR_IF functions.
433
+ // ERROR_IF is a predicate that must set an error if the condition holds.
434
+ // ===----------------------------------------------------------------------===//
435
+
436
+ template <typename T>
437
+ static LogicalResult verifyConvOpErrorIf (T op) {
438
+ llvm::ArrayRef<int64_t > padding = op.getPad ();
439
+ if (llvm::any_of (padding, [](int64_t p) { return p < 0 ; }))
440
+ return op.emitOpError (" expect all padding values to be >= 0, got " )
441
+ << padding;
442
+
443
+ llvm::ArrayRef<int64_t > strides = op.getStride ();
444
+ if (llvm::any_of (strides, [](int64_t s) { return s < 1 ; }))
445
+ return op.emitOpError (" expect all stride values to be >= 1, got " )
446
+ << strides;
447
+
448
+ llvm::ArrayRef<int64_t > dilations = op.getDilation ();
449
+ if (llvm::any_of (dilations, [](int64_t d) { return d < 1 ; }))
450
+ return op.emitOpError (" expect all dilation values to be >= 1, got " )
451
+ << dilations;
452
+
453
+ const RankedTensorType outputType =
454
+ llvm::dyn_cast<RankedTensorType>(op.getOutput ().getType ());
455
+ if (!outputType)
456
+ // Skip following checks if output is not ranked
457
+ return success ();
458
+
459
+ const RankedTensorType inputType =
460
+ llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
461
+ const RankedTensorType weightType =
462
+ llvm::dyn_cast<RankedTensorType>(op.getWeight ().getType ());
463
+
464
+ if (inputType && weightType) {
465
+ const auto verifyOutputSize =
466
+ [&op](const int64_t inputSize, const int64_t kernelSize,
467
+ const int64_t outputSize, const int64_t padBefore,
468
+ const int64_t padAfter, const int64_t stride,
469
+ const int64_t dilation, const llvm::StringRef dimName,
470
+ const llvm::StringRef dimAxis,
471
+ const llvm::StringRef padBeforeName,
472
+ const llvm::StringRef padAfterName) -> LogicalResult {
473
+ if (inputSize == ShapedType::kDynamic ||
474
+ kernelSize == ShapedType::kDynamic )
475
+ return success ();
476
+
477
+ // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
478
+
479
+ const std::optional<int64_t > calculatedOutSizeMinusOne = idivCheck (
480
+ inputSize - 1 + padBefore + padAfter - (kernelSize - 1 ) * dilation,
481
+ stride);
482
+ if (!calculatedOutSizeMinusOne.has_value ())
483
+ return op.emitOpError (" expected input_" )
484
+ << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
485
+ << padAfterName << " - (kernel_" << dimName
486
+ << " - 1) * dilation_" << dimAxis
487
+ << " to be wholly divisible by stride_" << dimAxis << " , got ("
488
+ << inputSize << " - 1 + " << padBefore << " + " << padAfter
489
+ << " - (" << kernelSize << " - 1) * " << dilation << " ) / "
490
+ << stride;
491
+
492
+ const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value () + 1 ;
493
+ if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
494
+ return op.emitOpError (" calculated output " )
495
+ << dimName << " did not match expected: "
496
+ << " calculated=" << calculatedOutSize
497
+ << " , expected=" << outputSize;
498
+
499
+ return success ();
500
+ };
501
+
502
+ // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
503
+ if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
504
+ if (failed (verifyOutputSize (
505
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
506
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
507
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
508
+ return failure ();
509
+
510
+ if (failed (verifyOutputSize (
511
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
512
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
513
+ dilations[1 ], " width" , " x" , " left" , " right" )))
514
+ return failure ();
515
+ }
516
+
517
+ // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
518
+ if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
519
+ if (failed (verifyOutputSize (
520
+ inputType.getDimSize (1 ), weightType.getDimSize (0 ),
521
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
522
+ dilations[0 ], " height" , " y" , " top" , " bottom" )))
523
+ return failure ();
524
+
525
+ if (failed (verifyOutputSize (
526
+ inputType.getDimSize (2 ), weightType.getDimSize (1 ),
527
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
528
+ dilations[1 ], " width" , " x" , " left" , " right" )))
529
+ return failure ();
530
+ }
531
+
532
+ // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
533
+ if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
534
+ if (failed (verifyOutputSize (
535
+ inputType.getDimSize (1 ), weightType.getDimSize (1 ),
536
+ outputType.getDimSize (1 ), padding[0 ], padding[1 ], strides[0 ],
537
+ dilations[0 ], " depth" , " d" , " front" , " back" )))
538
+ return failure ();
539
+
540
+ if (failed (verifyOutputSize (
541
+ inputType.getDimSize (2 ), weightType.getDimSize (2 ),
542
+ outputType.getDimSize (2 ), padding[2 ], padding[3 ], strides[1 ],
543
+ dilations[1 ], " height" , " y" , " top" , " bottom" )))
544
+ return failure ();
545
+
546
+ if (failed (verifyOutputSize (
547
+ inputType.getDimSize (3 ), weightType.getDimSize (3 ),
548
+ outputType.getDimSize (3 ), padding[4 ], padding[5 ], strides[2 ],
549
+ dilations[2 ], " width" , " x" , " left" , " right" )))
550
+ return failure ();
551
+ }
552
+ }
553
+
554
+ const RankedTensorType biasType =
555
+ llvm::dyn_cast<RankedTensorType>(op.getBias ().getType ());
556
+ if (!biasType)
557
+ // Skip following checks if bias is not ranked
558
+ return success ();
559
+
560
+ const int64_t biasChannels = biasType.getDimSize (0 );
561
+ const int64_t outputChannels = outputType.getDimSize (3 );
562
+ if (biasChannels == ShapedType::kDynamic ||
563
+ outputChannels == ShapedType::kDynamic )
564
+ // Skip following checks if biasChannels or outputChannels is dynamic dim
565
+ return success ();
566
+
567
+ if (biasChannels != outputChannels && biasChannels != 1 )
568
+ return op.emitOpError (
569
+ " bias channels expected to be equal to output channels (" )
570
+ << outputChannels << " ) or 1, got " << biasChannels;
571
+
572
+ return success ();
573
+ }
574
+
431
575
// verify that inType and outType have same element types
432
576
template <typename T>
433
577
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -2580,99 +2724,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
2580
2724
}
2581
2725
2582
2726
LogicalResult Conv2DOp::verify() {
  // Delegate to the shared convolution verifiers: operand/attribute checks,
  // accumulator/mode checks, and the TOSA ERROR_IF level checks.
  if (failed(verifyConvOp(*this)) || failed(verifyConvOpModes(*this)) ||
      failed(verifyConvOpErrorIf(*this)))
    return failure();
  return success();
}
2678
2732
@@ -2747,7 +2801,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
2747
2801
}
2748
2802
2749
2803
LogicalResult Conv3DOp::verify() {
  // Delegate to the shared convolution verifiers: operand/attribute checks,
  // accumulator/mode checks, and the TOSA ERROR_IF level checks.
  if (failed(verifyConvOp(*this)) || failed(verifyConvOpModes(*this)) ||
      failed(verifyConvOpErrorIf(*this)))
    return failure();
  return success();
}
@@ -2857,7 +2912,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2857
2912
}
2858
2913
2859
2914
LogicalResult DepthwiseConv2DOp::verify() {
  // Delegate to the shared convolution verifiers: operand/attribute checks,
  // accumulator/mode checks, and the TOSA ERROR_IF level checks.
  if (failed(verifyConvOp(*this)) || failed(verifyConvOpModes(*this)) ||
      failed(verifyConvOpErrorIf(*this)))
    return failure();
  return success();
}
0 commit comments