@@ -345,6 +345,12 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
345
345
}
346
346
}
347
347
348
+ static void getF64Values (ArrayAttr arrayAttr, SmallVector<double > &values) {
349
+ for (auto it : arrayAttr) {
350
+ values.push_back (it.cast <FloatAttr>().getValueAsDouble ());
351
+ }
352
+ }
353
+
348
354
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents (
349
355
MLIRContext *context, ::llvm::Optional<Location> location,
350
356
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -386,13 +392,13 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
386
392
387
393
// Copy the Operand's rank.
388
394
if (!hasRankedInput)
389
- outputShape.resize (operandTy.getRank (), - 1 );
395
+ outputShape.resize (operandTy.getRank (), ShapedType:: kDynamicSize );
390
396
391
397
// Copy shapes until the dim is non-dynamic.
392
398
for (int i = 0 , s = operandTy.getRank (); i < s; i++) {
393
399
if (i == axis || operandTy.isDynamicDim (i))
394
400
continue ;
395
- if (outputShape[i] == - 1 )
401
+ if (outputShape[i] == ShapedType:: kDynamicSize )
396
402
outputShape[i] = operandTy.getDimSize (i);
397
403
if (outputShape[i] != operandTy.getDimSize (i))
398
404
return failure ();
@@ -414,7 +420,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
414
420
// We need to know the length of the concatenation axis of all inputs to
415
421
// determine the dimension size of the output shape.
416
422
if (!operandTy.hasRank () || operandTy.isDynamicDim (axis)) {
417
- concatDimSize = - 1 ;
423
+ concatDimSize = ShapedType:: kDynamicSize ;
418
424
break ;
419
425
}
420
426
@@ -437,7 +443,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
437
443
438
444
// All shapes are dynamic.
439
445
SmallVector<int64_t > outShape;
440
- outShape.resize (2 , - 1 );
446
+ outShape.resize (2 , ShapedType:: kDynamicSize );
441
447
442
448
if (inputTy.hasRank ()) {
443
449
outShape[0 ] = inputTy.getDimSize (0 );
@@ -448,7 +454,8 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
448
454
}
449
455
450
456
if (biasTy.hasRank ()) {
451
- outShape[1 ] = outShape[1 ] == -1 ? biasTy.getDimSize (0 ) : outShape[1 ];
457
+ outShape[1 ] = outShape[1 ] == ShapedType::kDynamicSize ? biasTy.getDimSize (0 )
458
+ : outShape[1 ];
452
459
}
453
460
454
461
inferredReturnShapes.push_back (ShapedTypeComponents (outShape));
@@ -464,15 +471,16 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
464
471
465
472
// All shapes are dynamic.
466
473
SmallVector<int64_t > outShape;
467
- outShape.resize (3 , - 1 );
474
+ outShape.resize (3 , ShapedType:: kDynamicSize );
468
475
469
476
if (lhsTy.hasRank ()) {
470
477
outShape[0 ] = lhsTy.getDimSize (0 );
471
478
outShape[1 ] = lhsTy.getDimSize (1 );
472
479
}
473
480
474
481
if (rhsTy.hasRank ()) {
475
- outShape[0 ] = outShape[0 ] == -1 ? rhsTy.getDimSize (0 ) : outShape[0 ];
482
+ outShape[0 ] = outShape[0 ] == ShapedType::kDynamicSize ? rhsTy.getDimSize (0 )
483
+ : outShape[0 ];
476
484
outShape[2 ] = rhsTy.getDimSize (2 );
477
485
}
478
486
@@ -503,15 +511,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
503
511
return success ();
504
512
}
505
513
506
- outputShape.resize (paddingTy.getDimSize (0 ), - 1 );
514
+ outputShape.resize (paddingTy.getDimSize (0 ), ShapedType:: kDynamicSize );
507
515
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
508
516
return success ();
509
517
}
510
518
511
519
DenseIntElementsAttr paddings;
512
520
// If the paddings value is not a constant, all dimensions must be dynamic.
513
521
if (!matchPattern (operands[1 ], m_Constant (&paddings))) {
514
- outputShape.resize (inputTy.getRank (), - 1 );
522
+ outputShape.resize (inputTy.getRank (), ShapedType:: kDynamicSize );
515
523
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
516
524
return success ();
517
525
}
@@ -524,7 +532,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
524
532
outputShape.reserve (inputTy.getRank ());
525
533
for (int i = 0 , s = inputTy.getRank (); i < s; i++) {
526
534
if (inputTy.isDynamicDim (i)) {
527
- outputShape.push_back (- 1 );
535
+ outputShape.push_back (ShapedType:: kDynamicSize );
528
536
continue ;
529
537
}
530
538
@@ -574,7 +582,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
574
582
ShapedType inputTy = operands[0 ].getType ().cast <ShapedType>();
575
583
SmallVector<int64_t > outputShape;
576
584
if (!inputTy.hasRank ()) {
577
- outputShape.resize (multiples.size (), - 1 );
585
+ outputShape.resize (multiples.size (), ShapedType:: kDynamicSize );
578
586
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
579
587
return success ();
580
588
}
@@ -590,7 +598,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
590
598
outputShape.reserve (multiples.size ());
591
599
for (int i = 0 , s = inputTy.getRank (); i < s; i++) {
592
600
int dim = inputTy.getDimSize (i);
593
- if (dim != - 1 )
601
+ if (dim != ShapedType:: kDynamicSize )
594
602
dim *= multipleValues[i];
595
603
outputShape.push_back (dim);
596
604
}
@@ -622,14 +630,14 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
622
630
int64_t numElements = type.getNumElements ();
623
631
int64_t staticMul = 1 ;
624
632
for (auto val : newShapeValue) {
625
- if (val != - 1 ) {
633
+ if (val != ShapedType:: kDynamicSize ) {
626
634
staticMul *= val;
627
635
}
628
636
}
629
637
630
638
// Determine the length of the dynamic dimension.
631
639
for (auto &val : newShapeValue) {
632
- if (val == - 1 )
640
+ if (val == ShapedType:: kDynamicSize )
633
641
val = numElements / staticMul;
634
642
}
635
643
@@ -655,7 +663,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
655
663
// can determine the output rank.
656
664
SmallVector<int64_t > outputShape;
657
665
if (!inputTy.hasRank ()) {
658
- outputShape.resize (permsTy.getDimSize (0 ), - 1 );
666
+ outputShape.resize (permsTy.getDimSize (0 ), ShapedType:: kDynamicSize );
659
667
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
660
668
return success ();
661
669
}
@@ -684,7 +692,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
684
692
}
685
693
686
694
DenseIntElementsAttr perms;
687
- outputShape.resize (inputTy.getRank (), - 1 );
695
+ outputShape.resize (inputTy.getRank (), ShapedType:: kDynamicSize );
688
696
// If the permuations are a constant we can directly determine the output
689
697
// shape.
690
698
if (matchPattern (operands[1 ], m_Constant (&perms))) {
@@ -708,30 +716,100 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
708
716
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
709
717
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
710
718
llvm::SmallVector<int64_t > outputShape;
711
- outputShape.resize (3 , - 1 );
719
+ outputShape.resize (3 , ShapedType:: kDynamicSize );
712
720
713
721
if (auto ty = operands[0 ].getType ().dyn_cast <RankedTensorType>()) {
714
722
outputShape[0 ] = ty.getDimSize (0 );
715
723
outputShape[2 ] = ty.getDimSize (2 );
716
724
}
717
725
718
726
if (auto ty = operands[1 ].getType ().dyn_cast <RankedTensorType>()) {
719
- if (outputShape[0 ] == - 1 )
727
+ if (outputShape[0 ] == ShapedType:: kDynamicSize )
720
728
outputShape[0 ] = ty.getDimSize (0 );
721
- if (outputShape[1 ] == - 1 )
729
+ if (outputShape[1 ] == ShapedType:: kDynamicSize )
722
730
outputShape[1 ] = ty.getDimSize (1 );
723
731
}
724
732
725
733
inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
726
734
return success ();
727
735
}
728
736
737
+ LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
738
+ MLIRContext *context, ::llvm::Optional<Location> location,
739
+ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
740
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
741
+ llvm::SmallVector<int64_t , 4 > outputShape;
742
+ outputShape.resize (4 , ShapedType::kDynamicSize );
743
+
744
+ int32_t inHeight = ShapedType::kDynamicSize ;
745
+ int32_t inWidth = ShapedType::kDynamicSize ;
746
+
747
+ if (auto ty = operands[0 ].getType ().dyn_cast <RankedTensorType>()) {
748
+ outputShape[0 ] = ty.getDimSize (0 );
749
+ outputShape[3 ] = ty.getDimSize (3 );
750
+
751
+ inHeight = ty.getDimSize (1 );
752
+ inWidth = ty.getDimSize (2 );
753
+ }
754
+
755
+ int32_t shift =
756
+ attributes.get (" shift" ).cast <IntegerAttr>().getValue ().getSExtValue ();
757
+ llvm::SmallVector<int64_t > newShape;
758
+ getI64Values (attributes.get (" output_size" ).cast <ArrayAttr>(), newShape);
759
+ outputShape[1 ] = newShape[0 ];
760
+ outputShape[2 ] = newShape[1 ];
761
+
762
+ llvm::SmallVector<int64_t > strideInt;
763
+ llvm::SmallVector<int64_t > offsetInt;
764
+ llvm::SmallVector<double > strideFp;
765
+ llvm::SmallVector<double > offsetFp;
766
+ getI64Values (attributes.get (" offset" ).cast <ArrayAttr>(), offsetInt);
767
+ getF64Values (attributes.get (" offset_fp" ).cast <ArrayAttr>(), offsetFp);
768
+ getI64Values (attributes.get (" stride" ).cast <ArrayAttr>(), strideInt);
769
+ getF64Values (attributes.get (" stride_fp" ).cast <ArrayAttr>(), strideFp);
770
+
771
+ // If we have a 0 zero in integers we know that the resize indexing needs to
772
+ // be performed in floating point. Use the floating point varient to compute
773
+ // the resize shape.
774
+ bool fpMode = strideInt[0 ] == 0 ;
775
+
776
+ // We can compute the output shape if attribute specifies unknown dimensions
777
+ // based on the offset and stride. If we perfectly line up to the last index
778
+ // we need to round up the size to include it.
779
+ if (outputShape[1 ] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
780
+ float sizeFp = (inHeight - offsetFp[0 ] - 1 ) / strideFp[0 ];
781
+ float round = std::floor (sizeFp) == sizeFp ? 1 : 0 ;
782
+ outputShape[1 ] = std::ceil (sizeFp) + round;
783
+ }
784
+
785
+ if (outputShape[2 ] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
786
+ float sizeFp = (inWidth - offsetFp[1 ] - 1 ) / strideFp[1 ];
787
+ float round = std::floor (sizeFp) == sizeFp ? 1 : 0 ;
788
+ outputShape[2 ] = std::ceil (sizeFp) + round;
789
+ }
790
+
791
+ if (outputShape[1 ] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
792
+ int64_t size = (inHeight - 1 );
793
+ size = ((size << shift) - offsetInt[0 ]) / strideInt[0 ];
794
+ outputShape[1 ] = size + 1 ;
795
+ }
796
+
797
+ if (outputShape[2 ] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
798
+ int64_t size = (inWidth - 1 );
799
+ size = ((size << shift) - offsetInt[1 ]) / strideInt[1 ];
800
+ outputShape[2 ] = size + 1 ;
801
+ }
802
+
803
+ inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
804
+ return success ();
805
+ }
806
+
729
807
LogicalResult tosa::ScatterOp::inferReturnTypeComponents (
730
808
MLIRContext *context, ::llvm::Optional<Location> location,
731
809
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
732
810
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
733
811
llvm::SmallVector<int64_t > outputShape;
734
- outputShape.resize (3 , - 1 );
812
+ outputShape.resize (3 , ShapedType:: kDynamicSize );
735
813
736
814
if (auto ty = operands[0 ].getType ().dyn_cast <RankedTensorType>()) {
737
815
outputShape[0 ] = ty.getDimSize (0 );
@@ -740,14 +818,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
740
818
}
741
819
742
820
if (auto ty = operands[1 ].getType ().dyn_cast <RankedTensorType>()) {
743
- if (outputShape[0 ] == - 1 )
821
+ if (outputShape[0 ] == ShapedType:: kDynamicSize )
744
822
outputShape[0 ] = ty.getDimSize (0 );
745
823
}
746
824
747
825
if (auto ty = operands[2 ].getType ().dyn_cast <RankedTensorType>()) {
748
- if (outputShape[0 ] == - 1 )
826
+ if (outputShape[0 ] == ShapedType:: kDynamicSize )
749
827
outputShape[0 ] = ty.getDimSize (0 );
750
- if (outputShape[2 ] == - 1 )
828
+ if (outputShape[2 ] == ShapedType:: kDynamicSize )
751
829
outputShape[2 ] = ty.getDimSize (2 );
752
830
}
753
831
@@ -859,6 +937,7 @@ NARY_SHAPE_INFER(tosa::BitwiseAndOp)
859
937
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
860
938
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
861
939
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
940
+ NARY_SHAPE_INFER(tosa::CastOp)
862
941
NARY_SHAPE_INFER(tosa::CeilOp)
863
942
NARY_SHAPE_INFER(tosa::ClampOp)
864
943
NARY_SHAPE_INFER(tosa::ClzOp)
@@ -868,6 +947,7 @@ NARY_SHAPE_INFER(tosa::ExpOp)
868
947
NARY_SHAPE_INFER(tosa::FloorOp)
869
948
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
870
949
NARY_SHAPE_INFER(tosa::GreaterOp)
950
+ NARY_SHAPE_INFER(tosa::IdentityOp)
871
951
NARY_SHAPE_INFER(tosa::LogOp)
872
952
NARY_SHAPE_INFER(tosa::LogicalAndOp)
873
953
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
@@ -882,6 +962,7 @@ NARY_SHAPE_INFER(tosa::NegateOp)
882
962
NARY_SHAPE_INFER(tosa::PowOp)
883
963
NARY_SHAPE_INFER(tosa::ReciprocalOp)
884
964
NARY_SHAPE_INFER(tosa::ReluNOp)
965
+ NARY_SHAPE_INFER(tosa::RescaleOp)
885
966
NARY_SHAPE_INFER(tosa::ReverseOp)
886
967
NARY_SHAPE_INFER(tosa::RsqrtOp)
887
968
NARY_SHAPE_INFER(tosa::SelectOp)
0 commit comments