Skip to content

Commit b0b5d2f

Browse files
committed
TorchToLinAlg: fix tosa.clamp legalization for integer types.
1 parent 9bccb5b commit b0b5d2f

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -381,23 +381,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
381381

382382
if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
383383
auto intTy = elementTy.cast<IntegerType>();
384-
int32_t min = static_cast<int32_t>(
385-
op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
386-
int32_t max = static_cast<int32_t>(
387-
op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
384+
int64_t min =
385+
op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue();
386+
int64_t max =
387+
op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue();
388388

389389
if (intTy.isUnsignedInteger()) {
390-
min = std::max<int32_t>(min, 0);
391-
max = std::min<int32_t>(
390+
min = std::max(min, (int64_t)0);
391+
max = std::min(
392392
max,
393393
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
394394
} else {
395-
min = std::max<int32_t>(
396-
min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
397-
.getSExtValue());
398-
max = std::min<int32_t>(
399-
max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
400-
.getSExtValue());
395+
min =
396+
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
397+
.getSExtValue());
398+
max =
399+
std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
400+
.getSExtValue());
401401
}
402402

403403
auto minVal = rewriter.create<arith::ConstantIntOp>(

0 commit comments

Comments
 (0)