Skip to content

Commit e817a37

Browse files
Tai78641FranklandJack
authored andcommitted
[mlir][tosa] Add acc_type to Tosa-v1.0 Conv Ops
Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes. Signed-off-by: Tai Ly <[email protected]>
1 parent 998bdae commit e817a37

File tree

15 files changed

+299
-135
lines changed

15 files changed

+299
-135
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
126126
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
127127
"::mlir::Value":$weight, "::mlir::Value":$bias,
128128
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
129-
"::mlir::DenseI64ArrayAttr":$dilation),
129+
"::mlir::DenseI64ArrayAttr":$dilation,
130+
"::mlir::TypeAttr":$acc_type),
130131
[{
131132
buildConvOpWithQuantInfo($_builder, $_state, outputType,
132133
input, weight, bias,
133-
pad, stride, dilation);
134+
pad, stride, dilation, acc_type);
134135
}]>;
135136

136137
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
139140
"::mlir::Value":$weight, "mlir::Value":$bias,
140141
"::mlir::DenseI64ArrayAttr":$outpad,
141142
"::mlir::DenseI64ArrayAttr":$stride,
142-
"::mlir::DenseI64ArrayAttr":$outputShape),
143+
"::mlir::DenseI64ArrayAttr":$outputShape,
144+
"::mlir::TypeAttr":$acc_type),
143145
[{
144146
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
145147
input, weight, bias,
146148
outpad, stride,
147-
outputShape);
149+
outputShape, acc_type);
148150
}]>;
149151

150152
// The tosa.fully_connected op has its own builder as it does not have

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
5757
// Accumulator types.
5858
//===----------------------------------------------------------------------===//
5959

60-
def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
60+
def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
6161

6262
//===----------------------------------------------------------------------===//
6363
// Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
106106
Tosa_IntArrayAttr4:$pad,
107107
Tosa_IntArrayAttr2:$stride,
108108
Tosa_IntArrayAttr2:$dilation,
109+
TypeAttrOf<Tosa_AccType>:$acc_type,
109110
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
110111
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
111112
);
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
135136
Tosa_IntArrayAttr6:$pad,
136137
Tosa_IntArrayAttr3:$stride,
137138
Tosa_IntArrayAttr3:$dilation,
139+
TypeAttrOf<Tosa_AccType>:$acc_type,
138140
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
139141
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
140142
);
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165167
Tosa_IntArrayAttr4:$pad,
166168
Tosa_IntArrayAttr2:$stride,
167169
Tosa_IntArrayAttr2:$dilation,
170+
TypeAttrOf<Tosa_AccType>:$acc_type,
168171
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
169172
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
170173
);
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348351
Tosa_IntArrayAttr4:$out_pad,
349352
Tosa_IntArrayAttr2:$stride,
350353
Tosa_IntArrayAttr4:$out_shape,
354+
TypeAttrOf<Tosa_AccType>:$acc_type,
351355
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
352356
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
353357
);
@@ -357,6 +361,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
357361
);
358362

359363
let builders = [Tosa_TransConvOpQuantInfoBuilder];
364+
let hasVerifier = 1;
360365
}
361366

362367
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,26 @@ template <typename T>
210210
static LogicalResult verifyConvOp(T op) {
211211
// All TOSA conv ops have an input() and weight().
212212
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
213-
auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
213+
214+
RankedTensorType weightType;
215+
if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
216+
weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
217+
else
218+
weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
214219

215220
// Must be ranked tensor types
216221
if (!inputType) {
217222
op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
218223
return failure();
219224
}
220225
if (!weightType) {
221-
op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
226+
if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
227+
op.emitOpError("expect a ranked tensor for filter, got ")
228+
<< op.getFilter();
229+
} else {
230+
op.emitOpError("expect a ranked tensor for weight, got ")
231+
<< op.getWeight();
232+
}
222233
return failure();
223234
}
224235

@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
271282
return success();
272283
}
273284

285+
template <typename T>
286+
static LogicalResult verifyConvOpModes(T op) {
287+
auto inputEType =
288+
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
289+
290+
if (auto quantType =
291+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
292+
inputEType = quantType.getStorageType();
293+
294+
auto accType = op.getAccType();
295+
if (inputEType.isInteger(8) && !accType.isInteger(32))
296+
return op.emitOpError("accumulator type for i8 tensor is not i32");
297+
298+
if (inputEType.isInteger(16) && !accType.isInteger(48))
299+
return op.emitOpError("accumulator type for i16 tensor is not i48");
300+
301+
if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) &&
302+
!accType.isF16())
303+
return op.emitOpError("accumulator type for f8 tensor is not f16");
304+
305+
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
306+
return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
307+
308+
if (inputEType.isBF16() && !accType.isF32())
309+
return op.emitOpError("accumulator type for bf16 tensor is not f32");
310+
311+
if (inputEType.isF32() && !accType.isF32())
312+
return op.emitOpError("accumulator type for f32 tensor is not f32");
313+
314+
return success();
315+
}
316+
274317
LogicalResult tosa::ArgMaxOp::verify() {
275318
// Ensure output is of 32-bit integer
276319
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368411
Type outputType, Value input, Value weight,
369412
Value bias, DenseI64ArrayAttr pad,
370413
DenseI64ArrayAttr stride,
371-
DenseI64ArrayAttr dilation) {
414+
DenseI64ArrayAttr dilation,
415+
TypeAttr accType) {
372416

373417
result.addOperands({input, weight, bias});
374418
result.addAttribute("pad", pad);
375419
result.addAttribute("stride", stride);
376420
result.addAttribute("dilation", dilation);
421+
result.addAttribute("acc_type", accType);
377422

378423
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
379424
if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
390435
static void buildTransConvOpWithQuantInfo(
391436
OpBuilder &builder, OperationState &result, Type outputType, Value input,
392437
Value weight, Value bias, DenseI64ArrayAttr outpad,
393-
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
438+
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
394439
result.addOperands({input, weight, bias});
395440
result.addAttribute("out_pad", outpad);
396441
result.addAttribute("stride", stride);
397442
result.addAttribute("out_shape", outputShape);
443+
result.addAttribute("acc_type", accType);
398444
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
399445

400446
if (quantAttr) {
@@ -1595,7 +1641,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
15951641
return success();
15961642
}
15971643

1598-
LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1644+
LogicalResult Conv2DOp::verify() {
1645+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1646+
return failure();
1647+
return success();
1648+
}
15991649

16001650
LogicalResult Conv3DOp::inferReturnTypeComponents(
16011651
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1717,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
16671717
return success();
16681718
}
16691719

1670-
LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1720+
LogicalResult Conv3DOp::verify() {
1721+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1722+
return failure();
1723+
return success();
1724+
}
16711725

16721726
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
16731727
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1816,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
17621816
return success();
17631817
}
17641818

1765-
LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1819+
LogicalResult DepthwiseConv2DOp::verify() {
1820+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1821+
return failure();
1822+
return success();
1823+
}
17661824

17671825
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
17681826
MLIRContext *context, ::std::optional<Location> location,
@@ -1828,6 +1886,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
18281886
return success();
18291887
}
18301888

1889+
LogicalResult TransposeConv2DOp::verify() {
1890+
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1891+
return failure();
1892+
return success();
1893+
}
1894+
18311895
LogicalResult IfOp::inferReturnTypeComponents(
18321896
MLIRContext *context, ::std::optional<Location> location,
18331897
IfOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter
7575
loc, resultTy, input, reverse2, bias,
7676
rewriter.getDenseI64ArrayAttr(convPad),
7777
rewriter.getDenseI64ArrayAttr(stride),
78-
rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
78+
rewriter.getDenseI64ArrayAttr({1, 1}),
79+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
7980
} else {
8081
conv2d = rewriter.create<tosa::Conv2DOp>(
8182
loc, resultTy, input, reverse2, bias,
8283
rewriter.getDenseI64ArrayAttr(convPad),
8384
rewriter.getDenseI64ArrayAttr(stride),
84-
rewriter.getDenseI64ArrayAttr({1, 1}));
85+
rewriter.getDenseI64ArrayAttr({1, 1}),
86+
/* acc_type = */ op.getAccTypeAttr());
8587
}
8688

8789
rewriter.replaceOp(op, conv2d);
@@ -238,15 +240,16 @@ class TransposeConvStridedConverter
238240
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
239241
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
240242
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
241-
*op.getQuantizationInfo())
243+
/* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
242244
.getResult();
243245
} else {
244246
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
245247
rewriter, loc, UnrankedTensorType::get(resultETy), input,
246248
weight, zeroBias,
247249
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
248250
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
249-
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
251+
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
252+
/* acc_type = */ op.getAccTypeAttr())
250253
.getResult();
251254
}
252255

0 commit comments

Comments
 (0)