@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
46
46
op->getLoc (), IntegerAttr::get (requiredAttrType, castedN));
47
47
}
48
48
49
- static Value
50
- createLinalgBodyCalculationForElementwiseOp (Operation *op, ValueRange args,
51
- ArrayRef<Type> resultTypes,
52
- PatternRewriter &rewriter) {
49
+ static Value createLinalgBodyCalculationForElementwiseOp (
50
+ Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51
+ ConversionPatternRewriter &rewriter) {
53
52
Location loc = op->getLoc ();
54
53
auto elementTy =
55
54
cast<ShapedType>(op->getOperand (0 ).getType ()).getElementType ();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
186
185
Value max = rewriter.create <arith::ConstantIntOp>(
187
186
loc, APInt::getSignedMaxValue (inputBitWidth).getSExtValue (),
188
187
intermediateType);
189
- auto clamp = clampIntHelper (loc, sub, min, max, rewriter);
188
+ auto clamp =
189
+ clampIntHelper (loc, sub, min, max, rewriter, /* isUnsigned=*/ false );
190
190
191
191
// Truncate to the final value.
192
192
return rewriter.create <arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
389
389
int64_t max =
390
390
cast<IntegerAttr>(op->getAttr (" max_int" )).getValue ().getSExtValue ();
391
391
392
+ int64_t minRepresentable = std::numeric_limits<int64_t >::min ();
393
+ int64_t maxRepresentable = std::numeric_limits<int64_t >::max ();
392
394
if (intTy.isUnsignedInteger ()) {
393
- min = std::max (min, ( int64_t ) 0 ) ;
394
- max = std::min (
395
- max,
396
- APInt::getMaxValue (intTy. getIntOrFloatBitWidth ()). getSExtValue () );
397
- } else {
398
- min =
399
- std:: max(min, APInt::getSignedMinValue (intTy. getIntOrFloatBitWidth ())
400
- . getSExtValue ());
401
- max =
402
- std::min (max, APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
403
- .getSExtValue ()) ;
395
+ minRepresentable = 0 ;
396
+ if (intTy. getIntOrFloatBitWidth () <= 63 ) {
397
+ maxRepresentable = ( int64_t ) APInt::getMaxValue (intTy. getIntOrFloatBitWidth ())
398
+ . getZExtValue ( );
399
+ }
400
+ } else if (intTy. getIntOrFloatBitWidth () <= 64 ) {
401
+ // Ensure that min & max fit into signed n-bit constants.
402
+ minRepresentable = APInt::getSignedMinValue (intTy. getIntOrFloatBitWidth ())
403
+ . getSExtValue ();
404
+ maxRepresentable = APInt::getSignedMaxValue (intTy.getIntOrFloatBitWidth ())
405
+ .getSExtValue ();
404
406
}
407
+ // Ensure that the bounds are representable as n-bit signed/unsigned integers.
408
+ min = std::max (min, minRepresentable);
409
+ max = std::max (max, minRepresentable);
410
+ min = std::min (min, maxRepresentable);
411
+ max = std::min (max, maxRepresentable);
405
412
406
413
auto minVal = rewriter.create <arith::ConstantIntOp>(
407
414
loc, min, intTy.getIntOrFloatBitWidth ());
408
415
auto maxVal = rewriter.create <arith::ConstantIntOp>(
409
416
loc, max, intTy.getIntOrFloatBitWidth ());
410
- return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter);
417
+ return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter,
418
+ intTy.isUnsignedInteger ());
411
419
}
412
420
413
421
// tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
615
623
}
616
624
617
625
static SmallVector<Value> expandInputRanks (PatternRewriter &rewriter,
618
- Location loc, Operation *operation) {
619
- auto rank =
620
- cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
621
- return llvm::map_to_vector (operation->getOperands (), [&](Value operand) {
626
+ Location loc, ValueRange operands,
627
+ int64_t rank) {
628
+ return llvm::map_to_vector (operands, [&](Value operand) {
622
629
return expandRank (rewriter, loc, operand, rank);
623
630
});
624
631
}
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
843
850
}
844
851
845
852
static LogicalResult
846
- emitElementwiseComputation (PatternRewriter &rewriter, Location loc,
853
+ emitElementwiseComputation (ConversionPatternRewriter &rewriter, Location loc,
847
854
Operation *operation, ValueRange operands,
848
- ArrayRef<OpFoldResult> targetShape) {
855
+ ArrayRef<OpFoldResult> targetShape,
856
+ const TypeConverter &converter) {
849
857
// Generate output tensor
850
- auto resultType = cast<RankedTensorType>(operation->getResultTypes ().front ());
858
+ auto resultType = cast_or_null<RankedTensorType>(
859
+ converter.convertType (operation->getResultTypes ().front ()));
860
+ if (!resultType) {
861
+ return rewriter.notifyMatchFailure (operation, " failed to convert type" );
862
+ }
851
863
Value outputTensor = rewriter.create <tensor::EmptyOp>(
852
864
loc, targetShape, resultType.getElementType ());
853
865
@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
894
906
}
895
907
896
908
static LogicalResult
897
- elementwiseMatchAndRewriteHelper (Operation *operation,
898
- PatternRewriter &rewriter) {
909
+ elementwiseMatchAndRewriteHelper (Operation *operation, ValueRange operands,
910
+ ConversionPatternRewriter &rewriter,
911
+ const TypeConverter &converter) {
899
912
900
913
// Collect op properties
901
914
assert (operation->getNumResults () == 1 && " elementwise op expects 1 result" );
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
908
921
// Lower operation
909
922
IndexPool indexPool;
910
923
auto loc = operation->getLoc ();
911
- auto expandedOperands = expandInputRanks (rewriter, loc, operation);
924
+ auto rank =
925
+ cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
926
+ auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
912
927
auto [targetShape, masterOperands] =
913
928
computeTargetShape (rewriter, loc, indexPool, expandedOperands);
914
929
auto broadcastOperands = broadcastDynamicDimensions (
915
930
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
916
931
return emitElementwiseComputation (rewriter, loc, operation, broadcastOperands,
917
- targetShape);
932
+ targetShape, converter );
918
933
}
919
934
920
935
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1100
1115
namespace {
1101
1116
1102
1117
template <typename SrcOp>
1103
- class PointwiseConverter : public OpRewritePattern <SrcOp> {
1118
+ class PointwiseConverter : public OpConversionPattern <SrcOp> {
1104
1119
public:
1105
- using OpRewritePattern<SrcOp>::OpRewritePattern;
1120
+ using OpConversionPattern<SrcOp>::OpConversionPattern;
1121
+ using typename OpConversionPattern<SrcOp>::OpAdaptor;
1106
1122
1107
- LogicalResult matchAndRewrite (SrcOp op,
1108
- PatternRewriter &rewriter) const final {
1109
- return elementwiseMatchAndRewriteHelper (op, rewriter);
1123
+ LogicalResult
1124
+ matchAndRewrite (SrcOp op, OpAdaptor operands,
1125
+ ConversionPatternRewriter &rewriter) const final {
1126
+ return elementwiseMatchAndRewriteHelper (
1127
+ op, operands.getOperands (), rewriter, *this ->getTypeConverter ());
1110
1128
}
1111
1129
};
1112
1130
@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1279
1297
loc, nestedBuilder.getI32IntegerAttr (intMax));
1280
1298
1281
1299
value = clampIntHelper (nestedLoc, value, intMinVal, intMaxVal,
1282
- nestedBuilder);
1300
+ nestedBuilder, /* isUnsigned= */ false );
1283
1301
1284
1302
if (outIntType.getWidth () < 32 ) {
1285
1303
value = nestedBuilder.create <arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1643
1661
1644
1662
auto offset = b.create <arith::SelectOp>(pred, one, zeroI32);
1645
1663
val = b.create <arith::AddIOp>(val, offset);
1646
- val = clampIntHelper (loc, val, zeroI32, max, b);
1664
+ val = clampIntHelper (loc, val, zeroI32, max, b, /* isUnsigned= */ false );
1647
1665
return b.create <arith::IndexCastOp>(b.getIndexType (), val);
1648
1666
};
1649
1667
@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1664
1682
Value max, ImplicitLocOpBuilder &b) {
1665
1683
val0 = in;
1666
1684
val1 = b.create <arith::AddIOp>(val0, oneVal);
1667
- val0 = clampIntHelper (loc, val0, zeroI32, max, b);
1668
- val1 = clampIntHelper (loc, val1, zeroI32, max, b);
1685
+ val0 =
1686
+ clampIntHelper (loc, val0, zeroI32, max, b, /* isUnsigned=*/ false );
1687
+ val1 =
1688
+ clampIntHelper (loc, val1, zeroI32, max, b, /* isUnsigned=*/ false );
1669
1689
val0 = b.create <arith::IndexCastOp>(b.getIndexType (), val0);
1670
1690
val1 = b.create <arith::IndexCastOp>(b.getIndexType (), val1);
1671
1691
};
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2555
2575
} // namespace
2556
2576
2557
2577
void mlir::tosa::populateTosaToLinalgConversionPatterns (
2558
- RewritePatternSet *patterns) {
2578
+ TypeConverter &converter, RewritePatternSet *patterns) {
2559
2579
2560
2580
// We have multiple resize coverters to handle degenerate cases.
2561
2581
patterns->add <GenericResizeConverter>(patterns->getContext (),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
2602
2622
PointwiseConverter<tosa::CeilOp>,
2603
2623
PointwiseConverter<tosa::FloorOp>,
2604
2624
PointwiseConverter<tosa::ClampOp>,
2605
- PointwiseConverter<tosa::SigmoidOp>,
2625
+ PointwiseConverter<tosa::SigmoidOp>
2626
+ >(converter, patterns->getContext ());
2627
+
2628
+ patterns->add <
2606
2629
IdentityNConverter<tosa::IdentityOp>,
2607
2630
ReduceConverter<tosa::ReduceAllOp>,
2608
2631
ReduceConverter<tosa::ReduceAnyOp>,
0 commit comments