@@ -76,3 +76,25 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
76
76
%0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %zp , %zp {acc_type = f32 , pad = array<i64 : 1 , 1 , 1 , 1 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 xf32 >, tensor <1 x1 x2 x3 xf32 >, tensor <6 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <4 x12 x12 x6 xf32 >
77
77
return %0 : tensor <4 x12 x12 x6 xf32 >
78
78
}
79
+
80
+ // -----
81
+
82
+ // Decompose only support integer or float types.
83
+
84
+ // CHECK-LABEL: @depthwise_conv2d_quant_type
85
+ func.func @depthwise_conv2d_quant_type (%arg0: tensor <4 x10 x10 x2 x!quant.uniform <i8 :f32 , 0.015684768557548523 >>, %arg1: tensor <1 x1 x2 x3 x!quant.uniform <i8 <-127 :127 >:f32 , 0.015680249780416489 >>, %arg2: tensor <6 xi32 >) -> tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >> {
86
+ %0 = " tosa.const" () <{value = dense <7 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
87
+ %1 = " tosa.const" () <{value = dense <11 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
88
+ // CHECK: tosa.depthwise_conv2d
89
+ %2 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %0 , %1 {acc_type = i32 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 x!quant.uniform <i8 :f32 , 0.015684768557548523 >>, tensor <1 x1 x2 x3 x!quant.uniform <i8 <-127 :127 >:f32 , 0.015680249780416489 >>, tensor <6 xi32 >, tensor <1 xi8 >, tensor <1 xi8 >) -> tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >>
90
+ return %2 : tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >>
91
+ }
92
+
93
+ // -----
94
+
95
+ // CHECK-LABEL: @depthwise_conv2d_no_const_zero_point
96
+ func.func @depthwise_conv2d_no_const_zero_point (%arg0: tensor <4 x10 x10 x2 xi8 >, %arg1: tensor <1 x1 x2 x3 xi8 >, %arg2: tensor <6 xi32 >, %arg3: tensor <1 xi8 >, %arg4: tensor <1 xi8 >) -> tensor <4 x10 x10 x6 xi32 > {
97
+ // CHECK: tosa.depthwise_conv2d
98
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = i32 , pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 xi8 >, tensor <1 x1 x2 x3 xi8 >, tensor <6 xi32 >, tensor <1 xi8 >, tensor <1 xi8 >) -> tensor <4 x10 x10 x6 xi32 >
99
+ return %0 : tensor <4 x10 x10 x6 xi32 >
100
+ }
0 commit comments