@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
339
339
}
340
340
};
341
341
342
+ // Attempts the following transformation:
343
+ //
344
+ // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
345
+ // tensor X the following identity holds:
346
+ //
347
+ // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
348
+ //
349
+ // subject to the following valid NaN propagation semantics:
350
+ // --------------------------------------------
351
+ // | opNanMode | clampNanMode | resultNanMode |
352
+ // |-----------|--------------|---------------|
353
+ // | PROPAGATE | PROPAGATE | PROPAGATE |
354
+ // | PROPAGATE | IGNORE | IGNORE |
355
+ // | IGNORE | PROPAGATE | INVALID |
356
+ // | IGNORE | IGNORE | IGNORE |
357
+ // |------------------------------------------|
358
+
342
359
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  // Closed interval [start, end] describing the bounds of a clamp operation.
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Returns true when this range and `otherRange` share at least one value.
    // Both ranges are treated as closed, so ranges that only touch at a single
    // end point (e.g. [0, 5] and [5, 10]) still intersect: folding them yields
    // the degenerate but valid clamp [5, 5].
    bool intersects(const ClampRange<T> &otherRange) const {
      return start <= otherRange.end && otherRange.start <= end;
    }
  };

  // Folds `clamp(clamp(x, a, b), a', b')` into a single
  // `clamp(x, max(a, a'), min(b, b'))`.
  //
  // Fails to match (leaving the IR untouched) when:
  //   * the input of `op` is not produced by another tosa.clamp,
  //   * the outer op ignores NaNs while the inner op propagates them, since no
  //     single NaN mode reproduces that combination, or
  //   * the integer or the floating-point bounds of the two ops do not
  //     intersect.
  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting integer ranges. The bounds are stored as
    // int64_t so that every later comparison is signed.
    // NOTE(review): the accessor return type is not visible here; if it is
    // unsigned, comparing the raw values would mis-order negative bounds.
    const ClampRange<std::int64_t> opIntRange(op.getMinInt(), op.getMaxInt());
    const ClampRange<std::int64_t> clampIntRange(clampOp.getMinInt(),
                                                 clampOp.getMaxInt());
    if (!opIntRange.intersects(clampIntRange))
      return failure();

    // Check we have intersecting floating-point ranges.
    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    const ClampRange opFloatRange(opMinFloat, opMaxFloat);
    const ClampRange clampFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opFloatRange.intersects(clampFloatRange))
      return failure();

    // Run the transformation: the merged clamp uses the tightest bounds.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const std::int64_t minInt = std::max(opIntRange.start, clampIntRange.start);
    const std::int64_t maxInt = std::min(opIntRange.end, clampIntRange.end);

    // The only mixed combination that reaches this point is an outer
    // PROPAGATE over an inner IGNORE, which behaves like IGNORE; otherwise the
    // two modes agree and the shared mode is reused.
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};
371
422
0 commit comments