@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
259
259
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260
260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261
261
262
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
263
- if (llvm::failed (failureOrMaybeZps))
264
- return failure ();
262
+ // Get and verify zero points.
263
+ int64_t inputZpVal;
264
+ int64_t weightZpVal;
265
+
266
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
267
+ op.getWeightZeroPoint (weightZpVal).failed ())
268
+ return rewriter.notifyMatchFailure (
269
+ op, " bail out if zero points cannot statically be determined" );
270
+
271
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
272
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
273
+ return rewriter.notifyMatchFailure (
274
+ op, " zero point must be zero for non-int8 integer types" );
265
275
266
- auto maybeZps = failureOrMaybeZps. value ( );
276
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
267
277
268
278
if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
269
279
return rewriter.notifyMatchFailure (
@@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
289
299
290
300
// Apply padding as necessary.
291
301
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
292
- if (maybeZps ) {
302
+ if (hasZp ) {
293
303
int64_t intMin =
294
304
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
295
305
.getSExtValue ();
296
306
int64_t intMax =
297
307
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
298
308
.getSExtValue ();
299
309
300
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
310
+ if (inputZpVal < intMin || inputZpVal > intMax)
301
311
return rewriter.notifyMatchFailure (
302
312
op, " tosa.conv op quantization has zp outside of input range" );
303
313
304
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
314
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
305
315
}
306
316
307
317
llvm::SmallVector<int64_t > pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
314
324
// For 2D convolutions, we need to check if the target convolution op
315
325
// wants a HWCF kernel layout.
316
326
bool wantHwcf =
317
- maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318
- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
327
+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
328
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
319
329
if (wantHwcf) {
320
330
// Transpose the kernel to match dimension ordering of the linalg
321
331
// convolution operation.
@@ -376,9 +386,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
376
386
Value broadcastBias =
377
387
linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
378
388
379
- if (maybeZps ) {
380
- auto iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
381
- auto kZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
389
+ if (hasZp ) {
390
+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal );
391
+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal );
382
392
383
393
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
384
394
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -441,31 +451,40 @@ class DepthwiseConvConverter
441
451
/* inputSizeDims=*/ {1 , 2 },
442
452
/* kernelSizeDims=*/ {0 , 1 }, rewriter);
443
453
444
- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
445
- if (llvm::failed (failureOrMaybeZps))
446
- return failure ();
454
+ // Get and verify zero points.
455
+ int64_t inputZpVal;
456
+ int64_t weightZpVal;
457
+
458
+ if (op.getInputZeroPoint (inputZpVal).failed () ||
459
+ op.getWeightZeroPoint (weightZpVal).failed ())
460
+ return rewriter.notifyMatchFailure (
461
+ op, " bail out if zero points cannot statically be determined" );
447
462
448
- auto maybeZps = failureOrMaybeZps.value ();
463
+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
464
+ op.verifyWeightZeroPoint (weightZpVal).failed ())
465
+ return rewriter.notifyMatchFailure (
466
+ op, " zero point must be zero for non-int8 integer types" );
449
467
468
+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
450
469
auto weightShape = weightTy.getShape ();
451
470
auto resultShape = resultTy.getShape ();
452
471
453
472
// Apply padding as necessary.
454
473
TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
455
- if (maybeZps ) {
474
+ if (hasZp ) {
456
475
int64_t intMin =
457
476
APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
458
477
.getSExtValue ();
459
478
int64_t intMax =
460
479
APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
461
480
.getSExtValue ();
462
481
463
- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
482
+ if (inputZpVal < intMin || inputZpVal > intMax)
464
483
return rewriter.notifyMatchFailure (
465
484
op, " tosa.depthwise_conv op quantization has zp outside of input "
466
485
" range" );
467
486
468
- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
487
+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
469
488
}
470
489
471
490
llvm::SmallVector<int64_t > pad;
@@ -505,7 +524,7 @@ class DepthwiseConvConverter
505
524
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
506
525
indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
507
526
508
- if (!maybeZps ) {
527
+ if (!hasZp ) {
509
528
Value conv = rewriter
510
529
.create <linalg::DepthwiseConv2DNhwcHwcmOp>(
511
530
loc, linalgConvTy, ValueRange{input, weight},
@@ -532,8 +551,8 @@ class DepthwiseConvConverter
532
551
.getResult (0 );
533
552
rewriter.replaceOp (op, result);
534
553
} else {
535
- IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
536
- IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
554
+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal );
555
+ IntegerAttr wZp = rewriter.getI32IntegerAttr (weightZpVal );
537
556
auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
538
557
auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp);
539
558
Value conv =
0 commit comments