@@ -573,6 +573,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
573
573
return success ();
574
574
}
575
575
576
+ // Verify whether same type and shape of the given two types.
577
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, Type type1,
578
+ StringRef name1, Type type2,
579
+ StringRef name2) {
580
+ auto shapeType1 = dyn_cast<ShapedType>(type1);
581
+ auto shapeType2 = dyn_cast<ShapedType>(type2);
582
+ if (!shapeType1 || !shapeType2)
583
+ return failure ();
584
+
585
+ auto elemType1 = shapeType1.getElementType ();
586
+ auto elemType2 = shapeType2.getElementType ();
587
+ if (elemType1 != elemType2)
588
+ return op->emitOpError ()
589
+ << " require same element type for " << name1 << " (" << elemType1
590
+ << " ) and " << name2 << " (" << elemType2 << " )" ;
591
+
592
+ if (failed (verifyCompatibleShape (type1, type2)))
593
+ return op->emitOpError ()
594
+ << " require same shapes for " << name1 << " (" << type1 << " ) and "
595
+ << name2 << " (" << type2 << " )" ;
596
+
597
+ return success ();
598
+ }
599
+
600
+ // Verify whether same length, type, and shape of the given two tensor lists.
601
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, ValueRange list1,
602
+ StringRef name1,
603
+ ValueRange list2,
604
+ StringRef name2) {
605
+ if (list1.size () != list2.size ())
606
+ return op->emitOpError ()
607
+ << " require same number of values in " << name1 << " ("
608
+ << list1.size () << " ) and " << name2 << " (" << list2.size () << " )" ;
609
+
610
+ for (auto [type1, type2] :
611
+ llvm::zip_equal (list1.getTypes (), list2.getTypes ())) {
612
+ if (errorIfTypeOrShapeMismatch (op, type1, name1, type2, name2).failed ())
613
+ return failure ();
614
+ }
615
+
616
+ return success ();
617
+ }
618
+
619
+ static inline LogicalResult errorIfShapeNotSizeOne (Operation *op, Type type) {
620
+ ShapeAdaptor shapeAdaptor (type);
621
+ if (!shapeAdaptor.hasRank () || !shapeAdaptor.hasStaticShape ())
622
+ return success ();
623
+
624
+ return shapeAdaptor.getNumElements () == 1 ? success () : failure ();
625
+ }
626
+
576
627
// verify that inType and outType have same element types
577
628
template <typename T>
578
629
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -3601,6 +3652,84 @@ void IfOp::print(OpAsmPrinter &p) {
3601
3652
p.printOptionalAttrDict ((*this )->getAttrs ());
3602
3653
}
3603
3654
3655
+ LogicalResult IfOp::verify () {
3656
+ if (errorIfTypeOrShapeMismatch (*this , getThenGraph ().front ().getArguments (),
3657
+ " 'then_graph' arguments" , getInputList (),
3658
+ " 'input_list'" )
3659
+ .failed ())
3660
+ return failure ();
3661
+
3662
+ if (errorIfTypeOrShapeMismatch (*this , getElseGraph ().front ().getArguments (),
3663
+ " 'else_graph' arguments" , getInputList (),
3664
+ " 'input_list'" )
3665
+ .failed ())
3666
+ return failure ();
3667
+
3668
+ auto thenYield = cast<tosa::YieldOp>(getThenGraph ().front ().getTerminator ());
3669
+ if (errorIfTypeOrShapeMismatch (*this , thenYield.getInputs (),
3670
+ " 'then_graph' results" , getOutputList (),
3671
+ " 'output_list'" )
3672
+ .failed ())
3673
+ return failure ();
3674
+
3675
+ auto elseYield = cast<tosa::YieldOp>(getElseGraph ().front ().getTerminator ());
3676
+ if (errorIfTypeOrShapeMismatch (*this , elseYield.getInputs (),
3677
+ " 'else_graph' results" , getOutputList (),
3678
+ " 'output_list'" )
3679
+ .failed ())
3680
+ return failure ();
3681
+
3682
+ auto condType = getCondition ().getType ();
3683
+ if (errorIfShapeNotSizeOne (*this , condType).failed ())
3684
+ return emitOpError () << " 'condition' must be a size 1 tensor, got "
3685
+ << condType;
3686
+
3687
+ return success ();
3688
+ }
3689
+
3690
+ LogicalResult WhileOp::verify () {
3691
+ if (errorIfTypeOrShapeMismatch (*this , getInputList (), " 'input_list'" ,
3692
+ getOutputList (), " 'output_list'" )
3693
+ .failed ())
3694
+ return failure ();
3695
+
3696
+ if (errorIfTypeOrShapeMismatch (*this , getCondGraph ().front ().getArguments (),
3697
+ " 'cond_graph' arguments" , getInputList (),
3698
+ " 'input_list'" )
3699
+ .failed ())
3700
+ return failure ();
3701
+
3702
+ if (errorIfTypeOrShapeMismatch (*this , getBodyGraph ().front ().getArguments (),
3703
+ " 'body_graph' arguments" , getInputList (),
3704
+ " 'input_list'" )
3705
+ .failed ())
3706
+ return failure ();
3707
+
3708
+ auto bodyYield = cast<tosa::YieldOp>(getBodyGraph ().front ().getTerminator ());
3709
+ if (errorIfTypeOrShapeMismatch (*this , bodyYield.getInputs (),
3710
+ " 'body_graph' results" , getInputList (),
3711
+ " 'input_list'" )
3712
+ .failed ())
3713
+ return failure ();
3714
+
3715
+ // Condition block output must be a single element tensor with a single bool
3716
+ // value.
3717
+ auto condYield = cast<tosa::YieldOp>(getCondGraph ().front ().getTerminator ());
3718
+ if (condYield.getInputs ().size () != 1 )
3719
+ return emitOpError () << " require 'cond_graph' only have one result" ;
3720
+
3721
+ auto condOutType = condYield.getInputs ()[0 ].getType ();
3722
+ if (errorIfShapeNotSizeOne (*this , condOutType).failed ())
3723
+ return emitOpError () << " 'cond_graph' result must be a size 1 tensor, got "
3724
+ << condOutType;
3725
+
3726
+ if (!getElementTypeOrSelf (condOutType).isInteger (1 ))
3727
+ return emitOpError () << " 'cond_graph' result must be a boolean tensor, got "
3728
+ << condOutType;
3729
+
3730
+ return success ();
3731
+ }
3732
+
3604
3733
LogicalResult ReverseOp::verify () {
3605
3734
if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
3606
3735
/* outType = */ getOutput ().getType ())
0 commit comments