Skip to content

Commit 5152221

Browse files
Tai78641FranklandJack
authored andcommitted
[mlir][tosa] Add acc_type to Tosa-v1.0 Conv Ops
Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes. Signed-off-by: Tai Ly <[email protected]>
1 parent a3744f0 commit 5152221

File tree

17 files changed

+201
-130
lines changed

17 files changed

+201
-130
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
126126
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
127127
"::mlir::Value":$weight, "::mlir::Value":$bias,
128128
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
129-
"::mlir::DenseI64ArrayAttr":$dilation),
129+
"::mlir::DenseI64ArrayAttr":$dilation,
130+
"::mlir::TypeAttr":$acc_type),
130131
[{
131132
buildConvOpWithQuantInfo($_builder, $_state, outputType,
132133
input, weight, bias,
133-
pad, stride, dilation);
134+
pad, stride, dilation, acc_type);
134135
}]>;
135136

136137
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
139140
"::mlir::Value":$weight, "mlir::Value":$bias,
140141
"::mlir::DenseI64ArrayAttr":$outpad,
141142
"::mlir::DenseI64ArrayAttr":$stride,
142-
"::mlir::DenseI64ArrayAttr":$outputShape),
143+
"::mlir::DenseI64ArrayAttr":$outputShape,
144+
"::mlir::TypeAttr":$acc_type),
143145
[{
144146
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
145147
input, weight, bias,
146148
outpad, stride,
147-
outputShape);
149+
outputShape, acc_type);
148150
}]>;
149151

150152
// The tosa.fully_connected op has its own builder as it does not have

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
5757
// Accumulator types.
5858
//===----------------------------------------------------------------------===//
5959

60-
def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
60+
def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
6161

6262
//===----------------------------------------------------------------------===//
6363
// Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
106106
Tosa_IntArrayAttr4:$pad,
107107
Tosa_IntArrayAttr2:$stride,
108108
Tosa_IntArrayAttr2:$dilation,
109+
TypeAttrOf<Tosa_AccType>:$acc_type,
109110
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
110111
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
111112
);
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
135136
Tosa_IntArrayAttr6:$pad,
136137
Tosa_IntArrayAttr3:$stride,
137138
Tosa_IntArrayAttr3:$dilation,
139+
TypeAttrOf<Tosa_AccType>:$acc_type,
138140
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
139141
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
140142
);
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165167
Tosa_IntArrayAttr4:$pad,
166168
Tosa_IntArrayAttr2:$stride,
167169
Tosa_IntArrayAttr2:$dilation,
170+
TypeAttrOf<Tosa_AccType>:$acc_type,
168171
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
169172
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
170173
);
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348351
Tosa_IntArrayAttr4:$out_pad,
349352
Tosa_IntArrayAttr2:$stride,
350353
Tosa_IntArrayAttr4:$out_shape,
354+
TypeAttrOf<Tosa_AccType>:$acc_type,
351355
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
352356
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
353357
);

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
9696
criteria, e.g. TOSA profile.
9797
}];
9898

99+
let dependentDialects = [
100+
"func::FuncDialect",
101+
"tensor::TensorDialect",
102+
"tosa::TosaDialect",
103+
];
104+
99105
let options = [
100106
ListOption<"profile", "profile", "std::string",
101107
"Validate if operations match for the given profile set">,

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368368
Type outputType, Value input, Value weight,
369369
Value bias, DenseI64ArrayAttr pad,
370370
DenseI64ArrayAttr stride,
371-
DenseI64ArrayAttr dilation) {
371+
DenseI64ArrayAttr dilation,
372+
TypeAttr accType) {
372373

373374
result.addOperands({input, weight, bias});
374375
result.addAttribute("pad", pad);
375376
result.addAttribute("stride", stride);
376377
result.addAttribute("dilation", dilation);
378+
result.addAttribute("acc_type", accType);
377379

378380
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
379381
if (quantAttr) {
@@ -390,11 +392,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390392
static void buildTransConvOpWithQuantInfo(
391393
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392394
Value weight, Value bias, DenseI64ArrayAttr outpad,
393-
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
395+
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
394396
result.addOperands({input, weight, bias});
395397
result.addAttribute("out_pad", outpad);
396398
result.addAttribute("stride", stride);
397399
result.addAttribute("out_shape", outputShape);
400+
result.addAttribute("acc_type", accType);
398401
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
399402

400403
if (quantAttr) {

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter
7575
loc, resultTy, input, reverse2, bias,
7676
rewriter.getDenseI64ArrayAttr(convPad),
7777
rewriter.getDenseI64ArrayAttr(stride),
78-
rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
78+
rewriter.getDenseI64ArrayAttr({1, 1}),
79+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
7980
} else {
8081
conv2d = rewriter.create<tosa::Conv2DOp>(
8182
loc, resultTy, input, reverse2, bias,
8283
rewriter.getDenseI64ArrayAttr(convPad),
8384
rewriter.getDenseI64ArrayAttr(stride),
84-
rewriter.getDenseI64ArrayAttr({1, 1}));
85+
rewriter.getDenseI64ArrayAttr({1, 1}),
86+
/* acc_type = */ op.getAccTypeAttr());
8587
}
8688

8789
rewriter.replaceOp(op, conv2d);
@@ -238,15 +240,16 @@ class TransposeConvStridedConverter
238240
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
239241
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
240242
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
241-
*op.getQuantizationInfo())
243+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
242244
.getResult();
243245
} else {
244246
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
245247
rewriter, loc, UnrankedTensorType::get(resultETy), input,
246248
weight, zeroBias,
247249
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
248250
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
249-
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
251+
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
252+
/* acc_type = */ op.getAccTypeAttr())
250253
.getResult();
251254
}
252255

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818

1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
2021
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
2122
#include "mlir/IR/Builders.h"
2223
#include "mlir/IR/BuiltinOps.h"
@@ -74,6 +75,55 @@ static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
7475
return success();
7576
}
7677

78+
template <typename T>
79+
static LogicalResult verifyConvOpModes(T op) {
80+
auto inputEType =
81+
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
82+
83+
if (auto quantType =
84+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
85+
inputEType = quantType.getStorageType();
86+
87+
auto accType = op.getAccType();
88+
if (inputEType.isInteger(8) && !accType.isInteger(32))
89+
return op.emitOpError("accumulator type for i8 tensor is not i32");
90+
91+
if (inputEType.isInteger(16) && !accType.isInteger(48))
92+
return op.emitOpError("accumulator type for i16 tensor is not i48");
93+
94+
if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN()) &&
95+
!accType.isF16())
96+
return op.emitOpError("accumulator type for f8 tensor is not f16");
97+
98+
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
99+
return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
100+
101+
if (inputEType.isBF16() && !accType.isF32())
102+
return op.emitOpError("accumulator type for bf16 tensor is not f32");
103+
104+
if (inputEType.isF32() && !accType.isF32())
105+
return op.emitOpError("accumulator type for f32 tensor is not f32");
106+
107+
auto resultEType =
108+
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
109+
110+
if (auto quantType =
111+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
112+
resultEType = quantType.getStorageType();
113+
114+
// check allowed input/result element types combinations
115+
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
116+
(inputEType.isInteger(16) && resultEType.isInteger(48)) ||
117+
(inputEType.isFloat8E5M2() && resultEType.isF16()) ||
118+
(inputEType.isFloat8E4M3FN() && resultEType.isF16()) ||
119+
(inputEType.isF16() && resultEType.isF16()) ||
120+
(inputEType.isBF16() && resultEType.isBF16()) ||
121+
(inputEType.isF32() && resultEType.isF32()))
122+
return success();
123+
124+
return op.emitOpError("input/output element types are incompatible.");
125+
}
126+
77127
struct TosaLevel {
78128
int32_t MAX_RANK = 0;
79129
int32_t MAX_KERNEL = 0;
@@ -331,6 +381,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
331381
return false;
332382
}
333383
}
384+
return verifyConvOpModes(convOp).succeeded();
334385
}
335386
return true;
336387
}

0 commit comments

Comments
 (0)