Skip to content

Commit ca582b1

Browse files
Jerry-GeTai78641
andauthored
[mlir][tosa] Add FP8 lit tests (#127730)
Add FP8 lit tests to the following operators: ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Signed-off-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Co-authored-by: Tai Ly <[email protected]>
1 parent ca1833b commit ca582b1

File tree

2 files changed

+289
-8
lines changed

2 files changed

+289
-8
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,14 +522,7 @@ LogicalResult tosa::AvgPool2dOp::verify() {
522522
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
523523
return failure();
524524

525-
if ((inputETy.isF32() && resultETy.isF32()) ||
526-
(inputETy.isF16() && resultETy.isF16()) ||
527-
(inputETy.isBF16() && resultETy.isBF16()) ||
528-
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
529-
(inputETy.isInteger(16) && resultETy.isInteger(16)))
530-
return success();
531-
532-
return emitOpError("input/output element types are incompatible.");
525+
return success();
533526
}
534527

535528
LogicalResult tosa::ClampOp::verify() {

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
787787
%cst = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
788788
return %cst : !tosa.shape<4>
789789
}
790+
791+
// F8 support tests
792+
793+
// -----
794+
// CHECK-LABEL: argmax_f8E5M2
795+
func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
796+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
797+
return %0 : tensor<12x16xi32>
798+
}
799+
800+
// -----
801+
// CHECK-LABEL: avg_pool2d_f8E5M2
802+
func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
803+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
804+
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
805+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
806+
return %0 : tensor<1x7x7x9xf8E5M2>
807+
}
808+
809+
// -----
810+
// CHECK-LABEL: conv2d_f8E5M2
811+
func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
812+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
813+
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
814+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
815+
return %0 : tensor<1x4x4x8xf16>
816+
}
817+
818+
// -----
819+
// CHECK-LABEL: conv3d_f8E5M2
820+
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
821+
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
822+
return %0 : tensor<1x4x8x21x34xf16>
823+
}
824+
825+
// -----
826+
// CHECK-LABEL: depthwise_conv2d_f8E5M2
827+
func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
828+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
829+
return %0 : tensor<1x4x4x8xf16>
830+
}
831+
832+
// -----
833+
// CHECK-LABEL: test_matmul_f8E5M2
834+
func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
835+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
836+
return %0 : tensor<1x14x28xf16>
837+
}
838+
839+
// -----
840+
// CHECK-LABEL: max_pool2d_f8E5M2
841+
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
842+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
843+
return %0 : tensor<1x32x32x8xf8E5M2>
844+
}
845+
846+
// -----
847+
848+
// CHECK-LABEL: transpose_conv2d_f8E5M2
849+
func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16> {
850+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
851+
return %0 : tensor<1x32x32x16xf16>
852+
}
853+
854+
// -----
855+
// CHECK-LABEL: const_f8E5M2
856+
func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
857+
%0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
858+
return %0 : tensor<4xf8E5M2>
859+
}
860+
861+
// -----
862+
// CHECK-LABEL: cast_f8E5M2
863+
func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
864+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
865+
return %0 : tensor<13x21x3xf16>
866+
}
867+
868+
// -----
869+
// CHECK-LABEL: concat_f8E5M2
870+
func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
871+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
872+
return %0 : tensor<26x21x3xf8E5M2>
873+
}
874+
875+
// -----
876+
// CHECK-LABEL: pad_f8E5M2
877+
func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
878+
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
879+
%cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E5M2> } : () -> tensor<1xf8E5M2>
880+
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
881+
return %0 : tensor<13x21x3xf8E5M2>
882+
}
883+
884+
// -----
885+
// CHECK-LABEL: reshape_f8E5M2
886+
func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
887+
%1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
888+
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
889+
return %0 : tensor<1x819xf8E5M2>
890+
}
891+
892+
// -----
893+
// CHECK-LABEL: reverse_f8E5M2
894+
func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
895+
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
896+
return %0 : tensor<13x21x3xf8E5M2>
897+
}
898+
899+
// -----
900+
// CHECK-LABEL: slice_f8E5M2
901+
func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
902+
%0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
903+
%1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
904+
%2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
905+
return %2 : tensor<4x11x1xf8E5M2>
906+
}
907+
908+
// -----
909+
// CHECK-LABEL: tile_f8E5M2
910+
func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
911+
%cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
912+
%0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
913+
return %0 : tensor<39x21x6xf8E5M2>
914+
}
915+
916+
// -----
917+
func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
918+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2>
919+
return %1 : tensor<3x13x21xf8E5M2>
920+
}
921+
922+
// -----
923+
// CHECK-LABEL: gather_f8E5M2
924+
func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
925+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
926+
return %0 : tensor<13x26x3xf8E5M2>
927+
}
928+
929+
// -----
930+
// CHECK-LABEL: scatter_f8E5M2
931+
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
932+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
933+
return %0 : tensor<13x21x3xf8E5M2>
934+
}
935+
936+
// -----
937+
// CHECK-LABEL: argmax_f8E4M3FN
938+
func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
939+
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
940+
return %0 : tensor<12x16xi32>
941+
}
942+
943+
// -----
944+
// CHECK-LABEL: avg_pool2d_f8E4M3FN
945+
func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
946+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
947+
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
948+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
949+
return %0 : tensor<1x7x7x9xf8E4M3FN>
950+
}
951+
952+
// -----
953+
// CHECK-LABEL: conv2d_f8E4M3FN
954+
func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
955+
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
956+
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
957+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
958+
return %0 : tensor<1x4x4x8xf16>
959+
}
960+
961+
// -----
962+
// CHECK-LABEL: conv3d_f8E4M3FN
963+
func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
964+
%0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
965+
return %0 : tensor<1x4x8x21x34xf16>
966+
}
967+
968+
// -----
969+
// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
970+
func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16> {
971+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
972+
return %0 : tensor<1x4x4x8xf16>
973+
}
974+
975+
// -----
976+
// CHECK-LABEL: matmul_f8E4M3FN
977+
func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
978+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
979+
return %0 : tensor<1x14x28xf16>
980+
}
981+
982+
// -----
983+
// CHECK-LABEL: max_pool2d_f8E4M3FN
984+
func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
985+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
986+
return %0 : tensor<1x32x32x8xf8E4M3FN>
987+
}
988+
989+
// -----
990+
// CHECK-LABEL: transpose_conv2d_f8E4M3FN
991+
func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16> {
992+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16>
993+
return %0 : tensor<1x32x32x16xf16>
994+
}
995+
996+
// -----
997+
// CHECK-LABEL: const_f8E4M3FN
998+
func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
999+
%0 = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
1000+
return %0 : tensor<4xf8E4M3FN>
1001+
}
1002+
1003+
// -----
1004+
// CHECK-LABEL: cast_f8E4M3FN
1005+
func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
1006+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
1007+
return %0 : tensor<13x21x3xf16>
1008+
}
1009+
1010+
// -----
1011+
// CHECK-LABEL: concat_f8E4M3FN
1012+
func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
1013+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
1014+
return %0 : tensor<26x21x3xf8E4M3FN>
1015+
}
1016+
1017+
// -----
1018+
// CHECK-LABEL: pad_f8E4M3FN
1019+
func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1020+
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
1021+
%cst = "tosa.const"() { values = dense<-0.0> : tensor<1xf8E4M3FN> } : () -> tensor<1xf8E4M3FN>
1022+
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1023+
return %0 : tensor<13x21x3xf8E4M3FN>
1024+
}
1025+
1026+
// -----
1027+
// CHECK-LABEL: reshape_f8E4M3FN
1028+
func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
1029+
%1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
1030+
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
1031+
return %0 : tensor<1x819xf8E4M3FN>
1032+
}
1033+
1034+
// -----
1035+
// CHECK-LABEL: reverse_f8E4M3FN
1036+
func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1037+
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1038+
return %0 : tensor<13x21x3xf8E4M3FN>
1039+
}
1040+
1041+
// -----
1042+
// CHECK-LABEL: slice_f8E4M3FN
1043+
func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
1044+
%0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
1045+
%1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
1046+
%2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
1047+
return %2 : tensor<4x11x1xf8E4M3FN>
1048+
}
1049+
1050+
// -----
1051+
// CHECK-LABEL: tile_f8E4M3FN
1052+
func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
1053+
%cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
1054+
%0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
1055+
return %0 : tensor<39x21x6xf8E4M3FN>
1056+
}
1057+
1058+
// -----
1059+
// CHECK-LABEL: transpose_f8E4M3FN
1060+
func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
1061+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN>
1062+
return %1 : tensor<3x13x21xf8E4M3FN>
1063+
}
1064+
1065+
// -----
1066+
// CHECK-LABEL: gather_f8E4M3FN
1067+
func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
1068+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
1069+
return %0 : tensor<13x26x3xf8E4M3FN>
1070+
}
1071+
1072+
// -----
1073+
// CHECK-LABEL: scatter_f8E4M3FN
1074+
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1075+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1076+
return %0 : tensor<13x21x3xf8E4M3FN>
1077+
}

0 commit comments

Comments
 (0)