@@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
287
287
288
288
if (isa<FloatType>(inputElementType)) {
289
289
// Unlike integer types, floating point types can represent infinity.
290
- auto minClamp = op.getMinFp ();
291
- auto maxClamp = op.getMaxFp ();
292
- bool isMin = minClamp.isInfinity () && minClamp.isNegative ();
293
- bool isMax = maxClamp.isInfinity () && !maxClamp.isNegative ();
290
+ auto minClamp =
291
+ llvm::cast<mlir::FloatAttr>(op.getMinValAttr ()).getValue ();
292
+ auto maxClamp =
293
+ llvm::cast<mlir::FloatAttr>(op.getMaxValAttr ()).getValue ();
294
+ bool isMin = minClamp.isNegInfinity ();
295
+ bool isMax = maxClamp.isInfinity ();
294
296
295
297
if (isMin && isMax) {
296
298
rewriter.replaceOp (op, input);
@@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
300
302
}
301
303
302
304
if (inputElementType.isUnsignedInteger ()) {
303
- int64_t minClamp = op.getMinInt ();
304
- int64_t maxClamp = op.getMaxInt ();
305
+ int64_t minClamp =
306
+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getUInt ();
307
+ int64_t maxClamp =
308
+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getUInt ();
305
309
306
310
int64_t intMin =
307
311
APInt::getMinValue (inputElementType.getIntOrFloatBitWidth ())
@@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
318
322
}
319
323
320
324
if (llvm::isa<IntegerType>(inputElementType)) {
321
- int64_t minClamp = op.getMinInt ();
322
- int64_t maxClamp = op.getMaxInt ();
325
+ int64_t minClamp =
326
+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getInt ();
327
+ int64_t maxClamp =
328
+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getInt ();
323
329
324
330
int64_t intMin =
325
331
APInt::getSignedMinValue (inputElementType.getIntOrFloatBitWidth ())
@@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
374
380
375
381
LogicalResult matchAndRewrite (tosa::ClampOp op,
376
382
PatternRewriter &rewriter) const override {
383
+ Value input = op.getInput ();
384
+
377
385
// Check the input to the CLAMP op is itself a CLAMP.
378
- auto clampOp =
379
- dyn_cast_if_present<tosa::ClampOp>(op.getInput ().getDefiningOp ());
386
+ auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp ());
380
387
if (!clampOp)
381
388
return failure ();
382
389
@@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
386
393
if (opNanMode == " IGNORE" && clampNanMode == " PROPAGATE" )
387
394
return failure ();
388
395
389
- // Check we have intersecting ranges.
390
- const auto opMinInt = op.getMinInt ();
391
- const auto opMaxInt = op.getMaxInt ();
392
- const auto clampOpMinInt = clampOp.getMinInt ();
393
- const auto clampOpMaxInt = clampOp.getMaxInt ();
394
- ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
395
- ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt, clampOpMaxInt);
396
- if (!opRangeIntRange.intersects (clampRangeIntRange))
397
- return failure ();
396
+ auto maxValAttr = op.getMaxValAttr ();
397
+ auto minValAttr = op.getMinValAttr ();
398
+ auto clampOpMaxValAttr = clampOp.getMaxValAttr ();
399
+ auto clampOpMinValAttr = clampOp.getMinValAttr ();
398
400
399
- const auto opMinFloat = op.getMinFp ();
400
- const auto opMaxFloat = op.getMaxFp ();
401
- const auto clampOpMinFloat = clampOp.getMinFp ();
402
- const auto clampOpMaxFloat = clampOp.getMaxFp ();
403
- ClampRange<APFloat> opRangeFloatRange (opMinFloat, opMaxFloat);
404
- ClampRange<APFloat> clampRangeFloatRange (clampOpMinFloat, clampOpMaxFloat);
405
- if (!opRangeFloatRange.intersects (clampRangeFloatRange))
406
- return failure ();
401
+ auto inputEType = llvm::cast<ShapedType>(input.getType ()).getElementType ();
402
+ if (auto quantType =
403
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
404
+ inputEType = quantType.getStorageType ();
405
+ }
406
+
407
+ Attribute newMinValAttr, newMaxValAttr;
408
+ if (mlir::isa<FloatType>(inputEType)) {
409
+ auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
410
+ auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
411
+ auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
412
+ auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
413
+
414
+ // Check we have intersecting ranges.
415
+ const auto opMinFloat = floatMinValAttr.getValue ();
416
+ const auto opMaxFloat = floatMaxValAttr.getValue ();
417
+ const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue ();
418
+ const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue ();
419
+ ClampRange<APFloat> opRangeFloatRange (opMinFloat, opMaxFloat);
420
+ ClampRange<APFloat> clampRangeFloatRange (clampOpMinFloat,
421
+ clampOpMaxFloat);
422
+ if (!opRangeFloatRange.intersects (clampRangeFloatRange))
423
+ return failure ();
424
+
425
+ // Run the transformation.
426
+ auto newMinVal = std::max (opMinFloat, clampOpMinFloat);
427
+ auto newMaxVal = std::min (opMaxFloat, clampOpMaxFloat);
428
+ newMinValAttr = rewriter.getFloatAttr (inputEType, newMinVal);
429
+ newMaxValAttr = rewriter.getFloatAttr (inputEType, newMaxVal);
430
+ } else {
431
+ assert (mlir::isa<IntegerType>(inputEType));
432
+ auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
433
+ auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
434
+ auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
435
+ auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
436
+
437
+ if (inputEType.isUnsignedInteger ()) {
438
+ // Check we have intersecting ranges.
439
+ const auto opMinInt = intMinValAttr.getUInt ();
440
+ const auto opMaxInt = intMaxValAttr.getUInt ();
441
+ const auto clampOpMinInt = clampOpIntMinValAttr.getUInt ();
442
+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt ();
443
+ ClampRange<std::uint64_t > opRangeIntRange (opMinInt, opMaxInt);
444
+ ClampRange<std::uint64_t > clampRangeIntRange (clampOpMinInt,
445
+ clampOpMaxInt);
446
+ if (!opRangeIntRange.intersects (clampRangeIntRange))
447
+ return failure ();
448
+
449
+ // Run the transformation.
450
+ auto newMinVal = std::max (opMinInt, clampOpMinInt);
451
+ auto newMaxVal = std::min (opMaxInt, clampOpMaxInt);
452
+ newMinValAttr = rewriter.getIntegerAttr (inputEType, newMinVal);
453
+ newMaxValAttr = rewriter.getIntegerAttr (inputEType, newMaxVal);
454
+ } else {
455
+ // Check we have intersecting ranges.
456
+ const auto opMinInt = intMinValAttr.getInt ();
457
+ const auto opMaxInt = intMaxValAttr.getInt ();
458
+ const auto clampOpMinInt = clampOpIntMinValAttr.getInt ();
459
+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt ();
460
+ ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
461
+ ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt,
462
+ clampOpMaxInt);
463
+ if (!opRangeIntRange.intersects (clampRangeIntRange))
464
+ return failure ();
465
+
466
+ // Run the transformation.
467
+ auto newMinVal = std::max (opMinInt, clampOpMinInt);
468
+ auto newMaxVal = std::min (opMaxInt, clampOpMaxInt);
469
+ newMinValAttr = rewriter.getIntegerAttr (inputEType, newMinVal);
470
+ newMaxValAttr = rewriter.getIntegerAttr (inputEType, newMaxVal);
471
+ }
472
+ }
407
473
408
- // Run the transformation.
409
- const auto minFp = std::max (opMinFloat, clampOpMinFloat).convertToFloat ();
410
- const auto maxFp = std::min (opMaxFloat, clampOpMaxFloat).convertToFloat ();
411
- const auto minInt = std::max (opMinInt, clampOpMinInt);
412
- const auto maxInt = std::min (opMaxInt, clampOpMaxInt);
413
474
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
414
- op, op.getType (), clampOp.getInput (),
415
- rewriter.getI64IntegerAttr (minInt), rewriter.getI64IntegerAttr (maxInt),
416
- rewriter.getF32FloatAttr (minFp), rewriter.getF32FloatAttr (maxFp),
475
+ op, op.getType (), clampOp.getInput (), newMinValAttr, newMaxValAttr,
417
476
rewriter.getStringAttr ((opNanMode != clampNanMode) ? " IGNORE"
418
477
: opNanMode));
419
478
return success ();
0 commit comments