@@ -1166,6 +1166,57 @@ func.func @dot_general_multiple_dynamic_dims(
1166
1166
1167
1167
// -----
1168
1168
1169
// Quantized dot_general with a per-tensor quantized LHS (i8, scale 2.0,
// zero point 3) and a per-channel quantized RHS (i8 restricted to [-127, 127],
// quantized along axis 1 with scales {3.0, 4.0} and no explicit zero points).
// The result is per-channel along axis 1 with scales {6.0, 8.0}, i.e. the LHS
// scale multiplied by each RHS scale. The directives below verify the integer
// dot product and the zero-point correction that is applied to it.
// CHECK-LABEL: func @dot_general_per_channel
func.func @dot_general_per_channel(
    %arg0: tensor<?x2x!quant.uniform<i8:f32, 2.0:3>>,
    %arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.0,4.0}>>
  ) -> tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>> {
  // The raw dot product keeps the original contracting dimensions.
  // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general"
  // CHECK-SAME: lhs_contracting_dimensions = [1]
  // CHECK-SAME: rhs_contracting_dimensions = [0]>}

  // Zero point offset contribution from RHS tensor * LHS ZP: the RHS is
  // widened to i32, summed over its contracting dimension (0), and scaled by
  // the LHS zero point (3).

  // CHECK: %[[RHS_I32:.*]] = mhlo.convert %arg1 : (tensor<2x2xi8>)
  // CHECK-SAME: -> tensor<2x2xi32>
  // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]])
  // CHECK-SAME: applies mhlo.add across dimensions = [0]
  // CHECK-SAME: (tensor<2x2xi32>, tensor<i32>)
  // CHECK-SAME: -> tensor<2xi32>
  // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<3> : tensor<i32>
  // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply
  // CHECK-SAME: %[[RHS_REDUCE]], %[[LHS_ZP]] :
  // CHECK-SAME: (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>

  // Calculate output dynamic dims: dimension 0 is read from the dot result at
  // runtime, dimension 1 is the static size 2.
  // CHECK: %[[DIM_1_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]])
  // CHECK-SAME: {dimension = 0 : i64}
  // CHECK: %[[DIM_1_2:.*]] = mhlo.convert %[[DIM_1_1]] : (tensor<i32>) -> tensor<i64>
  // CHECK: %[[DIM_1:.*]] = mhlo.reshape %[[DIM_1_2]] : (tensor<i64>) -> tensor<1xi64>
  // CHECK: %[[DIM_2:.*]] = mhlo.constant dense<2> : tensor<1xi64>
  // CHECK: %[[OUTPUT_DIMS:.*]] = "mhlo.concatenate"
  // CHECK-SAME: %[[DIM_1]], %[[DIM_2]]

  // The per-column contribution is broadcast along result dimension 1 to the
  // full output shape, subtracted from a zero-initialized offset, and the
  // resulting offset is added to the raw dot product.
  // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"
  // CHECK-SAME: (%[[RHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]])
  // CHECK-SAME: broadcast_dimensions = dense<1>
  // CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor<?x2xi32>
  // CHECK: %[[ZPS_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZPS_INIT]], %[[RHS_ZP_BCAST]]
  // CHECK-SAME: (tensor<i32>, tensor<?x2xi32>) -> tensor<?x2xi32>
  // CHECK: chlo.broadcast_add %[[DOT_RES]], %[[ZP_TOTAL_2]]
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
      dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1],
                                        rhs_contracting_dimensions = [0]>} : (
      tensor<?x2x!quant.uniform<i8:f32, 2.0:3>>,
      tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.0,4.0}>>
    ) -> tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>>
  return %0 : tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>>
}
1217
+
1218
+ // -----
1219
+
1169
1220
// CHECK-LABEL: func @conv2d_dynamic
1170
1221
func.func @conv2d_dynamic (
1171
1222
%arg0: tensor <?x?x?x?x!quant.uniform <i8 :f32 , 2.000000e+00 :4 >>,
0 commit comments