Skip to content

Commit 131e59a

Browse files
Tai78641FranklandJack
authored andcommitted
[mlir][tosa] Add acc_type to Tosa-v1.0 Conv Ops
Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes. Signed-off-by: Tai Ly <[email protected]>
1 parent 998bdae commit 131e59a

File tree

15 files changed

+208
-133
lines changed

15 files changed

+208
-133
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
126126
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
127127
"::mlir::Value":$weight, "::mlir::Value":$bias,
128128
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
129-
"::mlir::DenseI64ArrayAttr":$dilation),
129+
"::mlir::DenseI64ArrayAttr":$dilation,
130+
"::mlir::TypeAttr":$acc_type),
130131
[{
131132
buildConvOpWithQuantInfo($_builder, $_state, outputType,
132133
input, weight, bias,
133-
pad, stride, dilation);
134+
pad, stride, dilation, acc_type);
134135
}]>;
135136

136137
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
139140
"::mlir::Value":$weight, "mlir::Value":$bias,
140141
"::mlir::DenseI64ArrayAttr":$outpad,
141142
"::mlir::DenseI64ArrayAttr":$stride,
142-
"::mlir::DenseI64ArrayAttr":$outputShape),
143+
"::mlir::DenseI64ArrayAttr":$outputShape,
144+
"::mlir::TypeAttr":$acc_type),
143145
[{
144146
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
145147
input, weight, bias,
146148
outpad, stride,
147-
outputShape);
149+
outputShape, acc_type);
148150
}]>;
149151

150152
// The tosa.fully_connected op has its own builder as it does not have

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
5757
// Accumulator types.
5858
//===----------------------------------------------------------------------===//
5959

60-
def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
60+
def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
6161

6262
//===----------------------------------------------------------------------===//
6363
// Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
106106
Tosa_IntArrayAttr4:$pad,
107107
Tosa_IntArrayAttr2:$stride,
108108
Tosa_IntArrayAttr2:$dilation,
109+
TypeAttrOf<Tosa_AccType>:$acc_type,
109110
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
110111
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
111112
);
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
135136
Tosa_IntArrayAttr6:$pad,
136137
Tosa_IntArrayAttr3:$stride,
137138
Tosa_IntArrayAttr3:$dilation,
139+
TypeAttrOf<Tosa_AccType>:$acc_type,
138140
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
139141
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
140142
);
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165167
Tosa_IntArrayAttr4:$pad,
166168
Tosa_IntArrayAttr2:$stride,
167169
Tosa_IntArrayAttr2:$dilation,
170+
TypeAttrOf<Tosa_AccType>:$acc_type,
168171
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
169172
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
170173
);
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348351
Tosa_IntArrayAttr4:$out_pad,
349352
Tosa_IntArrayAttr2:$stride,
350353
Tosa_IntArrayAttr4:$out_shape,
354+
TypeAttrOf<Tosa_AccType>:$acc_type,
351355
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
352356
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
353357
);

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,55 @@ LogicalResult tosa::ConstOp::verify() {
271271
return success();
272272
}
273273

274+
template <typename T>
275+
static LogicalResult verifyConvOpModes(T op) {
276+
auto inputEType =
277+
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
278+
279+
if (auto quantType =
280+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
281+
inputEType = quantType.getStorageType();
282+
283+
auto accType = op.getAccType();
284+
if (inputEType.isInteger(8) && !accType.isInteger(32))
285+
return op.emitOpError("accumulator type for i8 tensor is not i32");
286+
287+
if (inputEType.isInteger(16) && !accType.isInteger(48))
288+
return op.emitOpError("accumulator type for i16 tensor is not i48");
289+
290+
if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN()) &&
291+
!accType.isF16())
292+
return op.emitOpError("accumulator type for f8 tensor is not f16");
293+
294+
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
295+
return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
296+
297+
if (inputEType.isBF16() && !accType.isF32())
298+
return op.emitOpError("accumulator type for bf16 tensor is not f32");
299+
300+
if (inputEType.isF32() && !accType.isF32())
301+
return op.emitOpError("accumulator type for f32 tensor is not f32");
302+
303+
auto resultEType =
304+
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
305+
306+
if (auto quantType =
307+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
308+
resultEType = quantType.getStorageType();
309+
310+
// check allowed input/result element types combinations
311+
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
312+
(inputEType.isInteger(16) && resultEType.isInteger(48)) ||
313+
(inputEType.isFloat8E5M2() && resultEType.isF16()) ||
314+
(inputEType.isFloat8E4M3FN() && resultEType.isF16()) ||
315+
(inputEType.isF16() && resultEType.isF16()) ||
316+
(inputEType.isBF16() && resultEType.isBF16()) ||
317+
(inputEType.isF32() && resultEType.isF32()))
318+
return success();
319+
320+
return op.emitOpError("input/output element types are incompatible.");
321+
}
322+
274323
LogicalResult tosa::ArgMaxOp::verify() {
275324
// Ensure output is of 32-bit integer
276325
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -368,12 +417,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368417
Type outputType, Value input, Value weight,
369418
Value bias, DenseI64ArrayAttr pad,
370419
DenseI64ArrayAttr stride,
371-
DenseI64ArrayAttr dilation) {
420+
DenseI64ArrayAttr dilation,
421+
TypeAttr accType) {
372422

373423
result.addOperands({input, weight, bias});
374424
result.addAttribute("pad", pad);
375425
result.addAttribute("stride", stride);
376426
result.addAttribute("dilation", dilation);
427+
result.addAttribute("acc_type", accType);
377428

378429
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
379430
if (quantAttr) {
@@ -390,11 +441,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390441
static void buildTransConvOpWithQuantInfo(
391442
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392443
Value weight, Value bias, DenseI64ArrayAttr outpad,
393-
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
444+
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
394445
result.addOperands({input, weight, bias});
395446
result.addAttribute("out_pad", outpad);
396447
result.addAttribute("stride", stride);
397448
result.addAttribute("out_shape", outputShape);
449+
result.addAttribute("acc_type", accType);
398450
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
399451

400452
if (quantAttr) {
@@ -1595,7 +1647,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
15951647
return success();
15961648
}
15971649

1598-
LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1650+
LogicalResult Conv2DOp::verify() {
1651+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1652+
return failure();
1653+
return success();
1654+
}
15991655

16001656
LogicalResult Conv3DOp::inferReturnTypeComponents(
16011657
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1723,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
16671723
return success();
16681724
}
16691725

1670-
LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1726+
LogicalResult Conv3DOp::verify() {
1727+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1728+
return failure();
1729+
return success();
1730+
}
16711731

16721732
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
16731733
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1822,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
17621822
return success();
17631823
}
17641824

1765-
LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1825+
LogicalResult DepthwiseConv2DOp::verify() {
1826+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1827+
return failure();
1828+
return success();
1829+
}
17661830

17671831
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
17681832
MLIRContext *context, ::std::optional<Location> location,

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter
7575
loc, resultTy, input, reverse2, bias,
7676
rewriter.getDenseI64ArrayAttr(convPad),
7777
rewriter.getDenseI64ArrayAttr(stride),
78-
rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
78+
rewriter.getDenseI64ArrayAttr({1, 1}),
79+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
7980
} else {
8081
conv2d = rewriter.create<tosa::Conv2DOp>(
8182
loc, resultTy, input, reverse2, bias,
8283
rewriter.getDenseI64ArrayAttr(convPad),
8384
rewriter.getDenseI64ArrayAttr(stride),
84-
rewriter.getDenseI64ArrayAttr({1, 1}));
85+
rewriter.getDenseI64ArrayAttr({1, 1}),
86+
/* acc_type = */ op.getAccTypeAttr());
8587
}
8688

8789
rewriter.replaceOp(op, conv2d);
@@ -238,15 +240,16 @@ class TransposeConvStridedConverter
238240
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
239241
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
240242
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
241-
*op.getQuantizationInfo())
243+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
242244
.getResult();
243245
} else {
244246
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
245247
rewriter, loc, UnrankedTensorType::get(resultETy), input,
246248
weight, zeroBias,
247249
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
248250
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
249-
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
251+
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
252+
/* acc_type = */ op.getAccTypeAttr())
250253
.getResult();
251254
}
252255

0 commit comments

Comments
 (0)