Commit a1f4656

NatashaKnk authored and rsuderman committed
[mlir][tosa] Add quantized and unquantized versions for tosa.depthwise_conv2d lowering
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D107855
1 parent d39ebda commit a1f4656

File tree: 4 files changed (+274, -9 lines)

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 196 additions & 1 deletion
@@ -905,6 +905,202 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv_2D_nchw
+  cpp_class_name: DepthwiseConv2DNchwOp
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv2D_nchw_q
+  cpp_class_name: DepthwiseConv2DNchwQOp
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw
   cpp_class_name: Conv2DNchwOp
@@ -1700,4 +1896,3 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-
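Spelled out, the scalar assignments above amount to the following updates over the iteration domain (n, oh, ow, kh, kw, ic, cm), writing SH/SW for the strides s9/s10 and DH/DW for the dilations s11/s12 (a restatement of the YAML, not extra semantics):

  O[n, oh, ow, ic, cm] \mathrel{+}= \operatorname{cast}_U\big(I[n,\; oh \cdot SH + kh \cdot DH,\; ow \cdot SW + kw \cdot DW,\; ic]\big) \cdot \operatorname{cast}_U\big(K[kh, kw, ic, cm]\big)

and, for the quantized variant, with the zero points subtracted in the accumulator type U:

  O[n, oh, ow, ic, cm] \mathrel{+}= \big(\operatorname{cast}_U(I[n,\; oh \cdot SH + kh \cdot DH,\; ow \cdot SW + kw \cdot DW,\; ic]) - \operatorname{cast}_U(IZp)\big) \cdot \big(\operatorname{cast}_U(K[kh, kw, ic, cm]) - \operatorname{cast}_U(KZp)\big)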

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 18 additions & 6 deletions
@@ -946,19 +946,31 @@ convolutionMatchAndRewriterHelper(Operation *op,
     return success();
   }

-  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
+  if (isa<tosa::DepthwiseConv2DOp>(op)) {
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
                               resultETy);

     Value biasReshape =
         rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
-    Value conv = rewriter
-                     .create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
-                         loc, linalgConvTy, ValueRange{input, weight},
-                         ValueRange{biasReshape}, dilationAttr, strideAttr)
-                     .getResult(0);
+    Value conv;
+    if (!isQuantized) {
+      conv = rewriter
+                 .create<linalg::DepthwiseConv2DNchwOp>(
+                     loc, linalgConvTy, ValueRange{input, weight},
+                     ValueRange{biasReshape}, dilationAttr, strideAttr)
+                 .getResult(0);
+    } else {
+      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
+      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
+      conv =
+          rewriter
+              .create<linalg::DepthwiseConv2DNchwQOp>(
+                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
+                  ValueRange{biasReshape}, dilationAttr, strideAttr)
+              .getResult(0);
+    }

     Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
     rewriter.replaceOp(op, reshape);
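One detail worth calling out: in the quantized path, the zero points flow in as extra i32 operands, and the YAML scalar math casts I and K to the accumulator type before subtracting them. A minimal Python sketch of why that ordering matters, assuming i8 inputs and the input_zp of -128 used in the tests (illustrative only, not from the patch):

import numpy as np

i = np.int8(100)        # a sample quantized input element
izp = np.int32(-128)    # input zero point from the test case

# What the lowering computes: widen to i32 first, then subtract. Exact.
print(np.int32(i) - izp)                          # 228

# An i8 subtraction would wrap: 228 does not fit in int8.
wrapped = (int(i) - int(izp) + 128) % 256 - 128   # two's-complement wrap
print(wrapped)                                    # -28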

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 37 additions & 0 deletions
@@ -223,6 +223,43 @@ def conv_2d_nchw(
   ]) * cast(U, K[D.f, D.c, D.kh, D.kw])


+def depthwise_conv2D_nchw(  # TODO: Fix name.
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm)
+  O[D.n, D.oh, D.ow, D.ic, D.cm] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
+
+
+def depthwise_conv2D_nchw_q(  # TODO: Fix name.
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm)
+  O[D.n, D.oh, D.ow, D.ic, D.cm] += (
+      (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                 D.ic]) - cast(U, IZp)) *
+      (cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp)))
+
+
 @linalg_structured_op
 def pooling_nhwc_sum(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
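For readers less familiar with OpDSL, here is a plain NumPy rendering of what depthwise_conv2D_nchw_q computes. This is a hypothetical reference, not part of the patch: depthwise_conv2d_q_ref is an invented name, and unit strides/dilations with no padding are assumed for brevity. Shapes follow the TensorDefs above: I is (N, IH, IW, IC), K is (KH, KW, IC, CM), O is (N, OH, OW, IC, CM).

import numpy as np

def depthwise_conv2d_q_ref(I, K, izp, kzp):
    # Accumulate in i32, mirroring type_var U in the quantized op.
    n, ih, iw, ic = I.shape
    kh, kw, _, cm = K.shape
    oh, ow = ih - kh + 1, iw - kw + 1
    O = np.zeros((n, oh, ow, ic, cm), dtype=np.int32)
    for b in range(n):
        for y in range(oh):
            for x in range(ow):
                for c in range(ic):
                    for m in range(cm):
                        for dy in range(kh):
                            for dx in range(kw):
                                # Cast to i32, subtract zero points, multiply.
                                O[b, y, x, c, m] += (
                                    (np.int32(I[b, y + dy, x + dx, c]) - izp) *
                                    (np.int32(K[dy, dx, c, m]) - kzp))
    return O

# Tiny smoke test with the zero points from the quantized lowering test.
I = np.random.randint(-128, 128, size=(1, 5, 5, 3), dtype=np.int8)
K = np.random.randint(-128, 128, size=(3, 3, 3, 2), dtype=np.int8)
out = depthwise_conv2d_q_ref(I, K, np.int32(-128), np.int32(42))
print(out.shape)  # (1, 3, 3, 3, 2)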

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 23 additions & 2 deletions
@@ -1219,7 +1219,7 @@ func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>,
   // CHECK: linalg.yield %arg3 : i32
   // CHECK: %[[C128:.+]] = constant -128
   // CHECK: %[[C42:.+]] = constant 42
-  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>)
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>)
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32>
   return
 }
@@ -1237,14 +1237,35 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
   // CHECK: linalg.yield %arg3 : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
   // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
   // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
   %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
   return
 }

 // -----

+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @depthwise_conv_quant
+func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 512]
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x10x10x512xi32>) {
+  // CHECK: ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
+  // CHECK: linalg.yield %arg3 : i32
+  // CHECK: } -> tensor<1x10x10x512xi32>
+  // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: %[[C128:.+]] = constant -128
+  // CHECK: %[[C42:.+]] = constant 42
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[DBIAS]] : tensor<1x10x10x4x128xi32>)
+  // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @transpose_conv
 func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
   // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
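As a sanity check on the shapes in @depthwise_conv_quant above, they follow the standard convolution size arithmetic with no padding, unit stride, and unit dilation:

  OH = \frac{IH - DH\,(KH - 1) - 1}{SH} + 1 = \frac{12 - 2 - 1}{1} + 1 = 10, \qquad OW = 10

  C_{\mathrm{out}} = IC \cdot CM = 4 \cdot 128 = 512

which is why the linalg op produces tensor<1x10x10x4x128xi32> and the collapsed result is tensor<1x10x10x512xi32>.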
