Skip to content

Commit 1c45514

Browse files
authored
[mlir][tosa] Fix bug causing quantized pad const creation crash (#131125)
This commit ensures the storage type is retrieved correctly which fixes a crash when creating a quantized pad const tensor. Testing is completed via the `tosa-optional-decompositions` pass which makes use of the `createPadConstTensor` function. Also includes some cleanup.
1 parent c30ff92 commit 1c45514

File tree

2 files changed

+38
-25
lines changed

2 files changed

+38
-25
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,6 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
216216
}
217217
}
218218

219-
// Create a pad-const const tensor with value of `val` of required data-type
220-
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
221-
Value src, int32_t val) {
222-
const auto srcType = getElementTypeOrSelf(src);
223-
const auto srcElemType = getElementTypeOrSelf(src);
224-
const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
225-
const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
226-
const auto padConstAttr{
227-
llvm::isa<FloatType>(srcElemType)
228-
? DenseElementsAttr::get(padConstEType,
229-
builder.getFloatAttr(srcElemType, val))
230-
: DenseElementsAttr::get(padConstEType,
231-
builder.getIntegerAttr(srcElemType, val))};
232-
return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
233-
}
234-
235219
//===----------------------------------------------------------------------===//
236220
// Tosa utilities.
237221
//===----------------------------------------------------------------------===//
@@ -242,16 +226,15 @@ std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
242226
return lhs / rhs;
243227
}
244228

245-
//===----------------------------------------------------------------------===//
246-
// Tosa utilities.
247-
//===----------------------------------------------------------------------===//
248-
249-
static Type getStorageElementTypeOrSelf(Type type) {
250-
auto elementType = getElementTypeOrSelf(type);
251-
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elementType))
252-
elementType = quantType.getStorageType();
229+
Type getStorageElementTypeOrSelf(Type type) {
230+
auto srcType = getElementTypeOrSelf(type);
231+
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
232+
srcType = quantType.getStorageType();
233+
return srcType;
234+
}
253235

254-
return elementType;
236+
Type getStorageElementTypeOrSelf(Value value) {
237+
return getStorageElementTypeOrSelf(value.getType());
255238
}
256239

257240
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
@@ -273,6 +256,22 @@ static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
273256
return success();
274257
}
275258

259+
// Create a pad-const const tensor with value of `val` of required data-type
260+
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
261+
Value src, int32_t val) {
262+
const auto srcType = getElementTypeOrSelf(src);
263+
const auto srcElemType = getStorageElementTypeOrSelf(src);
264+
const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
265+
const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
266+
const auto padConstAttr{
267+
llvm::isa<FloatType>(srcElemType)
268+
? DenseElementsAttr::get(padConstEType,
269+
builder.getFloatAttr(srcElemType, val))
270+
: DenseElementsAttr::get(padConstEType,
271+
builder.getIntegerAttr(srcElemType, val))};
272+
return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
273+
}
274+
276275
//===----------------------------------------------------------------------===//
277276
// TOSA Operator Verifiers.
278277
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,20 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
131131

132132
// -----
133133

134+
// CHECK-LABEL: @transpose_conv2d_strided_quantized_quant_input
135+
func.func @transpose_conv2d_strided_quantized_quant_input(%arg0: tensor<2x17x15x3x!quant.uniform<i8:f32, 0.015684274956583977:-1>>, %arg1: tensor<5x3x5x3x!quant.uniform<i8:f32, 0.015684274956583977:-1>>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
136+
// Checks a regression. A typo in `createPadConstTensor` caused the conversion to crash
137+
// CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
138+
// CHECK-DAG: %[[PAD_CONST:.+]] = "tosa.const"() <{values = dense<-22> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 0.015684274956583977:-1>>
139+
// CHECK: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_CONST]] : (tensor<2x17x15x3x!quant.uniform<i8:f32, 0.015684274956583977:-1>>, !tosa.shape<8>, tensor<1x!quant.uniform<i8:f32, 0.015684274956583977:-1>>)
140+
%input_zp = "tosa.const"() <{values = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
141+
%weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
142+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3x!quant.uniform<i8:f32, 0.015684274956583977:-1>>, tensor<5x3x5x3x!quant.uniform<i8:f32, 0.015684274956583977:-1>>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x35x47x5xi32>
143+
return %0 : tensor<2x35x47x5xi32>
144+
}
145+
146+
// -----
147+
134148
// CHECK-LABEL: @transpose_conv2d_strided_overpad
135149
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
136150
// CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>

0 commit comments

Comments
 (0)