Skip to content

Commit 443e666

Browse files
tensorflower-gardener (TensorFlow MLIR Team)
authored and committed
Support Per-channel quantization for DotGeneral
PiperOrigin-RevId: 617644231
1 parent f200cf4 commit 443e666

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,10 +1040,9 @@ FailureOr<bool> isDotLikeOpHybrid(DotLikeOp op) {
10401040
getElementTypeOrSelf(op.getResult()));
10411041

10421042
if (isLhsQuant && ((isRhsQuant && isResQuant) ||
1043-
(isa<mhlo::ConvolutionOp>(op) && isRhsQuantPerChannel &&
1044-
isResQuantPerChannel))) {
1045-
// For quantized ops, RHS and result must be both per-channel quantized.
1046-
// For Convolution, we also support per-channel quantized RHS/result.
1043+
(isRhsQuantPerChannel && isResQuantPerChannel))) {
1044+
// For quantized ops, RHS and result must be both per-channel quantized or
1045+
// both per-tensor quantized.
10471046
return false;
10481047
}
10491048
if (!isLhsQuant && !isLhsQuantPerChannel && isRhsQuant && !isResQuant &&

tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,57 @@ func.func @dot_general_multiple_dynamic_dims(
11661166

11671167
// -----
11681168

1169+
// CHECK-LABEL: func @dot_general_per_channel
1170+
func.func @dot_general_per_channel(
1171+
%arg0: tensor<?x2x!quant.uniform<i8:f32, 2.0:3>>,
1172+
%arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.0,4.0}>>
1173+
) -> tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>> {
1174+
// CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general"
1175+
// CHECK-SAME: lhs_contracting_dimensions = [1]
1176+
// CHECK-SAME: rhs_contracting_dimensions = [0]>}
1177+
1178+
// Zero point offset contribution from RHS tensor * LHS ZP.
1179+
1180+
// CHECK: %[[RHS_I32:.*]] = mhlo.convert %arg1 : (tensor<2x2xi8>)
1181+
// CHECK-SAME: -> tensor<2x2xi32>
1182+
// CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
1183+
// CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]])
1184+
// CHECK-SAME: applies mhlo.add across dimensions = [0]
1185+
// CHECK-SAME: (tensor<2x2xi32>, tensor<i32>)
1186+
// CHECK-SAME: -> tensor<2xi32>
1187+
// CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor<i32>
1188+
// CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply
1189+
// CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] :
1190+
// CHECK-SAME: (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
1191+
1192+
// Calculate output dynamic dims.
1193+
// CHECK: %[[DIM_1_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]])
1194+
// CHECK-SAME: {dimension = 0 : i64}
1195+
// CHECK: %[[DIM_1_2:.*]] = mhlo.convert %[[DIM_1_1]] : (tensor<i32>) -> tensor<i64>
1196+
// CHECK: %[[DIM_1:.*]] = mhlo.reshape %[[DIM_1_2]] : (tensor<i64>) -> tensor<1xi64>
1197+
// CHECK: %[[DIM_2:.*]] = mhlo.constant dense<2> : tensor<1xi64>
1198+
// CHECK: %[[OUTPUT_DIMS:.*]] = "mhlo.concatenate"
1199+
// CHECK-SAME: %[[DIM_1]], %[[DIM_2]]
1200+
1201+
// CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"
1202+
// CHECK-SAME: (%[[RHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]])
1203+
// CHECK-SAME: broadcast_dimensions = dense<1>
1204+
// CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor<?x2xi32>
1205+
// CHECK: %[[ZPS_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
1206+
// CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZPS_INIT]], %[[RHS_ZP_BCAST]]
1207+
// CHECK-SAME: (tensor<i32>, tensor<?x2xi32>) -> tensor<?x2xi32>
1208+
// CHECK: chlo.broadcast_add %[[DOT_RES]], %[[ZP_TOTAL_2]]
1209+
%0 = "mhlo.dot_general"(%arg0, %arg1) {
1210+
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1],
1211+
rhs_contracting_dimensions = [0]>} : (
1212+
tensor<?x2x!quant.uniform<i8:f32, 2.0:3>>,
1213+
tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.0,4.0}>>
1214+
) -> tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>>
1215+
return %0 : tensor<?x2x!quant.uniform<i32:f32:1, {6.0,8.0}>>
1216+
}
1217+
1218+
// -----
1219+
11691220
// CHECK-LABEL: func @conv2d_dynamic
11701221
func.func @conv2d_dynamic(
11711222
%arg0: tensor<?x?x?x?x!quant.uniform<i8:f32, 2.000000e+00:4>>,

0 commit comments

Comments (0)