@@ -787,3 +787,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
787
787
%cst = tosa.const_shape {values = dense <1 > : tensor <4 xindex >} : () -> !tosa.shape <4 >
788
788
return %cst : !tosa.shape <4 >
789
789
}
790
+
791
+ // F8 support tests
792
+
793
+ // -----
794
+ // CHECK-LABEL: argmax_f8E5M2
795
+ func.func @test_argmax_f8E5M2 (%arg0: tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 > {
796
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 >
797
+ return %0 : tensor <12 x16 xi32 >
798
+ }
799
+
800
+ // -----
801
+ // CHECK-LABEL: avg_pool2d_f8E5M2
802
+ func.func @test_avg_pool2d_f8E5M2 (%arg0: tensor <1 x7 x7 x9 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 > {
803
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
804
+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
805
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 >
806
+ return %0 : tensor <1 x7 x7 x9 xf8 E5 M2 >
807
+ }
808
+
809
+ // -----
810
+ // CHECK-LABEL: conv2d_f8E5M2
811
+ func.func @test_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <8 x1 x1 x4 xf8 E5 M2 >, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
812
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
813
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
814
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <8 x1 x1 x4 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
815
+ return %0 : tensor <1 x4 x4 x8 xf16 >
816
+ }
817
+
818
+ // -----
819
+ // CHECK-LABEL: conv3d_f8E5M2
820
+ func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
821
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <34 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
822
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
823
+ }
824
+
825
+ // -----
826
+ // CHECK-LABEL: depthwise_conv2d_f8E5M2
827
+ func.func @test_depthwise_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <1 x1 x4 x2 xf8 E5 M2 >, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 > {
828
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <1 x1 x4 x2 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
829
+ return %0 : tensor <1 x4 x4 x8 xf16 >
830
+ }
831
+
832
+ // -----
833
+ // CHECK-LABEL: matmul_f8E5M2
834
+ func.func @test_matmul_f8E5M2 (%arg0: tensor <1 x14 x19 xf8 E5 M2 >, %arg1: tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 > {
835
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E5 M2 >, tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 >
836
+ return %0 : tensor <1 x14 x28 xf16 >
837
+ }
838
+
839
+ // -----
840
+ // CHECK-LABEL: max_pool2d_f8E5M2
841
+ func.func @test_max_pool2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 > {
842
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 >
843
+ return %0 : tensor <1 x32 x32 x8 xf8 E5 M2 >
844
+ }
845
+
846
+ // -----
847
+
848
+ // CHECK-LABEL: transpose_conv2d_f8E5M2
849
+ func.func @test_transpose_conv2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >, %arg1: tensor <16 x1 x1 x8 xf8 E5 M2 >, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 > {
850
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >, tensor <16 x1 x1 x8 xf8 E5 M2 >, tensor <16 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 >
851
+ return %0 : tensor <1 x32 x32 x16 xf16 >
852
+ }
853
+
854
+ // -----
855
+ // CHECK-LABEL: const_f8E5M2
856
+ func.func @test_const_f8E5M2 (%arg0 : index ) -> tensor <4 xf8 E5 M2 > {
857
+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E5 M2 >} : () -> tensor <4 xf8 E5 M2 >
858
+ return %0 : tensor <4 xf8 E5 M2 >
859
+ }
860
+
861
+ // -----
862
+ // CHECK-LABEL: cast_f8E5M2
863
+ func.func @test_cast_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 > {
864
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 >
865
+ return %0 : tensor <13 x21 x3 xf16 >
866
+ }
867
+
868
+ // -----
869
+ // CHECK-LABEL: concat_f8E5M2
870
+ func.func @test_concat_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 > {
871
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 >
872
+ return %0 : tensor <26 x21 x3 xf8 E5 M2 >
873
+ }
874
+
875
+ // -----
876
+ // CHECK-LABEL: pad_f8E5M2
877
+ func.func @test_pad_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
878
+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
879
+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E5 M2 > } : () -> tensor <1 xf8 E5 M2 >
880
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <6 >, tensor <1 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
881
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
882
+ }
883
+
884
+ // -----
885
+ // CHECK-LABEL: reshape_f8E5M2
886
+ func.func @test_reshape_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <1 x819 xf8 E5 M2 > {
887
+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
888
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <2 >) -> tensor <1 x819 xf8 E5 M2 >
889
+ return %0 : tensor <1 x819 xf8 E5 M2 >
890
+ }
891
+
892
+ // -----
893
+ // CHECK-LABEL: reverse_f8E5M2
894
+ func.func @test_reverse_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
895
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
896
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
897
+ }
898
+
899
+ // -----
900
+ // CHECK-LABEL: slice_f8E5M2
901
+ func.func @test_slice_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <4 x11 x1 xf8 E5 M2 > {
902
+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
903
+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
904
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E5 M2 >
905
+ return %2 : tensor <4 x11 x1 xf8 E5 M2 >
906
+ }
907
+
908
+ // -----
909
+ // CHECK-LABEL: tile_f8E5M2
910
+ func.func @test_tile_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <39 x21 x6 xf8 E5 M2 > {
911
+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
912
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E5 M2 >
913
+ return %0 : tensor <39 x21 x6 xf8 E5 M2 >
914
+ }
915
+
916
+ // -----
+ // CHECK-LABEL: transpose_f8E5M2
917
+ func.func @test_transpose_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 > {
918
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 >
919
+ return %1 : tensor <3 x13 x21 xf8 E5 M2 >
920
+ }
921
+
922
+ // -----
923
+ // CHECK-LABEL: gather_f8E5M2
924
+ func.func @test_gather_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 > {
925
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 >
926
+ return %0 : tensor <13 x26 x3 xf8 E5 M2 >
927
+ }
928
+
929
+ // -----
930
+ // CHECK-LABEL: scatter_f8E5M2
931
+ func.func @test_scatter_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
932
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
933
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
934
+ }
935
+
936
+ // -----
937
+ // CHECK-LABEL: argmax_f8E4M3FN
938
+ func.func @test_argmax_f8E4M3FN (%arg0: tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 > {
939
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 >
940
+ return %0 : tensor <12 x16 xi32 >
941
+ }
942
+
943
+ // -----
944
+ // CHECK-LABEL: avg_pool2d_f8E4M3FN
945
+ func.func @test_avg_pool2d_f8E4M3FN (%arg0: tensor <1 x7 x7 x9 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN> {
946
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
947
+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
948
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN>
949
+ return %0 : tensor <1 x7 x7 x9 xf8 E4 M3 FN>
950
+ }
951
+
952
+ // -----
953
+ // CHECK-LABEL: conv2d_f8E4M3FN
954
+ func.func @test_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <8 x1 x1 x4 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
955
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
956
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
957
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <8 x1 x1 x4 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
958
+ return %0 : tensor <1 x4 x4 x8 xf16 >
959
+ }
960
+
961
+ // -----
962
+ // CHECK-LABEL: conv3d_f8E4M3FN
963
+ func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
964
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <34 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
965
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
966
+ }
967
+
968
+ // -----
969
+ // CHECK-LABEL: depthwise_conv2d_f8E4M3FN
970
+ func.func @test_depthwise_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <1 x1 x4 x2 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 > {
971
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <1 x1 x4 x2 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
972
+ return %0 : tensor <1 x4 x4 x8 xf16 >
973
+ }
974
+
975
+ // -----
976
+ // CHECK-LABEL: matmul_f8E4M3FN
977
+ func.func @test_matmul_f8E4M3FN (%arg0: tensor <1 x14 x19 xf8 E4 M3 FN>, %arg1: tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 > {
978
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E4 M3 FN>, tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 >
979
+ return %0 : tensor <1 x14 x28 xf16 >
980
+ }
981
+
982
+ // -----
983
+ // CHECK-LABEL: max_pool2d_f8E4M3FN
984
+ func.func @test_max_pool2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN> {
985
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN>
986
+ return %0 : tensor <1 x32 x32 x8 xf8 E4 M3 FN>
987
+ }
988
+
989
+ // -----
990
+ // CHECK-LABEL: transpose_conv2d_f8E4M3FN
991
+ func.func @test_transpose_conv2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>, %arg1: tensor <16 x1 x1 x8 xf8 E4 M3 FN>, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 > {
992
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>, tensor <16 x1 x1 x8 xf8 E4 M3 FN>, tensor <16 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 >
993
+ return %0 : tensor <1 x32 x32 x16 xf16 >
994
+ }
995
+
996
+ // -----
997
+ // CHECK-LABEL: const_f8E4M3FN
998
+ func.func @test_const_f8E4M3FN (%arg0 : index ) -> tensor <4 xf8 E4 M3 FN> {
999
+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E4 M3 FN>} : () -> tensor <4 xf8 E4 M3 FN>
1000
+ return %0 : tensor <4 xf8 E4 M3 FN>
1001
+ }
1002
+
1003
+ // -----
1004
+ // CHECK-LABEL: cast_f8E4M3FN
1005
+ func.func @test_cast_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 > {
1006
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 >
1007
+ return %0 : tensor <13 x21 x3 xf16 >
1008
+ }
1009
+
1010
+ // -----
1011
+ // CHECK-LABEL: concat_f8E4M3FN
1012
+ func.func @test_concat_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN> {
1013
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN>
1014
+ return %0 : tensor <26 x21 x3 xf8 E4 M3 FN>
1015
+ }
1016
+
1017
+ // -----
1018
+ // CHECK-LABEL: pad_f8E4M3FN
1019
+ func.func @test_pad_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1020
+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
1021
+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E4 M3 FN> } : () -> tensor <1 xf8 E4 M3 FN>
1022
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <6 >, tensor <1 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1023
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1024
+ }
1025
+
1026
+ // -----
1027
+ // CHECK-LABEL: reshape_f8E4M3FN
1028
+ func.func @test_reshape_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <1 x819 xf8 E4 M3 FN> {
1029
+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
1030
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <2 >) -> tensor <1 x819 xf8 E4 M3 FN>
1031
+ return %0 : tensor <1 x819 xf8 E4 M3 FN>
1032
+ }
1033
+
1034
+ // -----
1035
+ // CHECK-LABEL: reverse_f8E4M3FN
1036
+ func.func @test_reverse_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1037
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1038
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1039
+ }
1040
+
1041
+ // -----
1042
+ // CHECK-LABEL: slice_f8E4M3FN
1043
+ func.func @test_slice_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <4 x11 x1 xf8 E4 M3 FN> {
1044
+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1045
+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1046
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E4 M3 FN>
1047
+ return %2 : tensor <4 x11 x1 xf8 E4 M3 FN>
1048
+ }
1049
+
1050
+ // -----
1051
+ // CHECK-LABEL: tile_f8E4M3FN
1052
+ func.func @test_tile_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <39 x21 x6 xf8 E4 M3 FN> {
1053
+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
1054
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E4 M3 FN>
1055
+ return %0 : tensor <39 x21 x6 xf8 E4 M3 FN>
1056
+ }
1057
+
1058
+ // -----
1059
+ // CHECK-LABEL: transpose_f8E4M3FN
1060
+ func.func @test_transpose_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN> {
1061
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN>
1062
+ return %1 : tensor <3 x13 x21 xf8 E4 M3 FN>
1063
+ }
1064
+
1065
+ // -----
1066
+ // CHECK-LABEL: gather_f8E4M3FN
1067
+ func.func @test_gather_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN> {
1068
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN>
1069
+ return %0 : tensor <13 x26 x3 xf8 E4 M3 FN>
1070
+ }
1071
+
1072
+ // -----
1073
+ // CHECK-LABEL: scatter_f8E4M3FN
1074
+ func.func @test_scatter_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1075
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1076
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1077
+ }
0 commit comments