Skip to content

Commit 1a8d2a4

Browse files
Tai78641Tessil
andauthored
[mlir][tosa] Use generic QuantizedType in Conv verifiers (#126275)
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
1 parent 06f4fe3 commit 1a8d2a4

File tree

3 files changed

+55
-9
lines changed

3 files changed

+55
-9
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240240
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241241
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242242

243-
if (auto quantType =
244-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245244
inputEType = quantType.getStorageType();
246245

247-
if (auto quantType =
248-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249247
biasEType = quantType.getStorageType();
250248

251-
if (auto quantType =
252-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253250
resultEType = quantType.getStorageType();
254251

255252
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346343
auto inputEType =
347344
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
348345

349-
if (auto quantType =
350-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351347
inputEType = quantType.getStorageType();
352348

353349
auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369365
if (inputEType.isF32() && !accType.isF32())
370366
return op.emitOpError("accumulator type for f32 tensor is not f32");
371367

372-
return success();
368+
auto resultEType =
369+
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
370+
371+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372+
resultEType = quantType.getStorageType();
373+
374+
// check allowed input/result element types combinations
375+
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
376+
(inputEType.isInteger(16) && resultEType.isInteger(48)) ||
377+
(isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
378+
(isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
379+
(inputEType.isF16() && resultEType.isF16()) ||
380+
(inputEType.isBF16() && resultEType.isBF16()) ||
381+
(inputEType.isF32() && resultEType.isF32()))
382+
return success();
383+
384+
return op.emitOpError("input/output element types are incompatible.");
373385
}
374386

375387
// verify that inType and outType have same element types

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1
144144
return %0 : tensor<1x32x32x16xi8>
145145
}
146146

147+
// -----
148+
// CHECK-LABEL: conv2d_quant_any_acc
149+
func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
150+
%zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
151+
// expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
152+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
153+
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
154+
}
155+
156+
// -----
157+
// CHECK-LABEL: conv2d_quant_any_result
158+
func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
159+
%zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
160+
// expected-error@+1 {{'tosa.conv2d' op input/output element types are incompatible}}
161+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
162+
return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
163+
}
164+
147165
// -----
148166

149167
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
5858
return %0 : tensor<1x4x4x8xf32>
5959
}
6060

61+
// -----
62+
// CHECK-LABEL: conv2d_quant_uniform
63+
func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
64+
%zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
65+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, tensor<8x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>>
66+
return %0 : tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>>
67+
}
68+
69+
// -----
70+
// CHECK-LABEL: conv2d_quant_any
71+
func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
72+
%zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
73+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
74+
return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
75+
}
76+
6177
// -----
6278
// CHECK-LABEL: conv2d_q8xi4
6379
func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {

0 commit comments

Comments
 (0)