@@ -672,6 +672,63 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
672
672
673
673
// -----
674
674
675
+ // CHECK-LABEL: @conv2d_bias_broadcast_f32
676
+ func.func @conv2d_bias_broadcast_f32 (%input: tensor <1 x49 x42 x27 xf32 >, %weights: tensor <28 x3 x3 x27 xf32 >) -> () {
677
+ %bias = " tosa.const" () <{values = dense <4.20 > : tensor <28 xf32 >}> : () -> tensor <28 xf32 >
678
+ // CHECK-DAG: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
679
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
680
+ // CHECK: %[[BIAS:.+]] = linalg.fill
681
+ // CHECK-SAME: ins(%[[CST]]
682
+ // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xf32>
683
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
684
+ // CHECK-SAME: outs(%[[BIAS]]
685
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf32 >}> : () -> tensor <1 xf32 >
686
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf32 >}> : () -> tensor <1 xf32 >
687
+ %0 = tosa.conv2d %input , %weights , %bias , %input_zp , %weight_zp {acc_type = f32 , pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 2 , 1 >} : (tensor <1 x49 x42 x27 xf32 >, tensor <28 x3 x3 x27 xf32 >, tensor <28 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <1 x45 x40 x28 xf32 >
688
+ return
689
+ }
690
+
691
+ // -----
692
+
693
+ // CHECK-LABEL: @conv2d_dynamic_batch_bias_broadcast_f32
694
+ // CHECK-SAME: (%[[INPUT:.+]]: tensor<?x49x42x27xf32>
695
+ func.func @conv2d_dynamic_batch_bias_broadcast_f32 (%input: tensor <?x49 x42 x27 xf32 >, %weights: tensor <28 x3 x3 x27 xf32 >) -> () {
696
+ %bias = " tosa.const" () <{values = dense <4.20 > : tensor <28 xf32 >}> : () -> tensor <28 xf32 >
697
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
698
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x49x42x27xf32>
699
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x45x40x28xf32>
700
+ // CHECK: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
701
+ // CHECK: %[[BIAS:.+]] = linalg.fill
702
+ // CHECK-SAME: ins(%[[CST]]
703
+ // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<?x45x40x28xf32>
704
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
705
+ // CHECK-SAME: outs(%[[BIAS]]
706
+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf32 >}> : () -> tensor <1 xf32 >
707
+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf32 >}> : () -> tensor <1 xf32 >
708
+ %0 = tosa.conv2d %input , %weights , %bias , %input_zp , %weight_zp {acc_type = f32 , pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 2 , 1 >} : (tensor <?x49 x42 x27 xf32 >, tensor <28 x3 x3 x27 xf32 >, tensor <28 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <?x45 x40 x28 xf32 >
709
+ return
710
+ }
711
+
712
+ // -----
713
+
714
+ // CHECK-LABEL: @conv2d_bias_broadcast_i8_acc_i32
715
+ func.func @conv2d_bias_broadcast_i8_acc_i32 (%input: tensor <1 x49 x42 x27 xi8 >, %weights: tensor <28 x3 x3 x27 xi8 >) -> () {
716
+ %bias = " tosa.const" () <{values = dense <42 > : tensor <28 xi8 >}> : () -> tensor <28 xi8 >
717
+ // CHECK-DAG: %[[CST:.+]] = arith.constant 42 : i32
718
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
719
+ // CHECK: %[[BIAS:.+]] = linalg.fill
720
+ // CHECK-SAME: ins(%[[CST]]
721
+ // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xi32>
722
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
723
+ // CHECK-SAME: outs(%[[BIAS]]
724
+ %input_zp = " tosa.const" () <{values = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
725
+ %weight_zp = " tosa.const" () <{values = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
726
+ %0 = tosa.conv2d %input , %weights , %bias , %input_zp , %weight_zp {acc_type = i32 , pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 2 , 1 >} : (tensor <1 x49 x42 x27 xi8 >, tensor <28 x3 x3 x27 xi8 >, tensor <28 xi8 >, tensor <1 xi8 >, tensor <1 xi8 >) -> tensor <1 x45 x40 x28 xi32 >
727
+ return
728
+ }
729
+
730
+ // -----
731
+
675
732
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
676
733
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
677
734
0 commit comments