@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
46
46
op->getLoc (), IntegerAttr::get (requiredAttrType, castedN));
47
47
}
48
48
49
- static Value
50
- createLinalgBodyCalculationForElementwiseOp (Operation *op, ValueRange args,
51
- ArrayRef<Type> resultTypes,
52
- PatternRewriter &rewriter) {
49
+ static Value createLinalgBodyCalculationForElementwiseOp (
50
+ Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51
+ ConversionPatternRewriter &rewriter) {
53
52
Location loc = op->getLoc ();
54
53
auto elementTy =
55
54
cast<ShapedType>(op->getOperand (0 ).getType ()).getElementType ();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
186
185
Value max = rewriter.create <arith::ConstantIntOp>(
187
186
loc, APInt::getSignedMaxValue (inputBitWidth).getSExtValue (),
188
187
intermediateType);
189
- auto clamp = clampIntHelper (loc, sub, min, max, rewriter);
188
+ auto clamp =
189
+ clampIntHelper (loc, sub, min, max, rewriter, /* isUnsigned=*/ false );
190
190
191
191
// Truncate to the final value.
192
192
return rewriter.create <arith::TruncIOp>(loc, elementTy, clamp);
@@ -390,10 +390,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
390
390
cast<IntegerAttr>(op->getAttr (" max_int" )).getValue ().getSExtValue ();
391
391
392
392
if (intTy.isUnsignedInteger ()) {
393
+ if (intTy.getIntOrFloatBitWidth () > 63 ) {
394
+ (void )rewriter.notifyMatchFailure (
395
+ op, " support for 64-bit or larger integers is not implemented" );
396
+ return {};
397
+ }
393
398
min = std::max (min, (int64_t )0 );
394
- max = std::min (
395
- max,
396
- APInt::getMaxValue (intTy. getIntOrFloatBitWidth ()). getSExtValue ());
399
+ max = std::min (max,
400
+ ( int64_t ) APInt::getMaxValue (intTy. getIntOrFloatBitWidth ())
401
+ . getZExtValue ());
397
402
} else {
398
403
min =
399
404
std::max (min, APInt::getSignedMinValue (intTy.getIntOrFloatBitWidth ())
@@ -407,7 +412,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
407
412
loc, min, intTy.getIntOrFloatBitWidth ());
408
413
auto maxVal = rewriter.create <arith::ConstantIntOp>(
409
414
loc, max, intTy.getIntOrFloatBitWidth ());
410
- return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter);
415
+ return clampIntHelper (loc, args[0 ], minVal, maxVal, rewriter,
416
+ intTy.isUnsignedInteger ());
411
417
}
412
418
413
419
// tosa::SigmoidOp
@@ -615,10 +621,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
615
621
}
616
622
617
623
static SmallVector<Value> expandInputRanks (PatternRewriter &rewriter,
618
- Location loc, Operation *operation) {
619
- auto rank =
620
- cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
621
- return llvm::map_to_vector (operation->getOperands (), [&](Value operand) {
624
+ Location loc, ValueRange operands,
625
+ int64_t rank) {
626
+ return llvm::map_to_vector (operands, [&](Value operand) {
622
627
return expandRank (rewriter, loc, operand, rank);
623
628
});
624
629
}
@@ -843,11 +848,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
843
848
}
844
849
845
850
static LogicalResult
846
- emitElementwiseComputation (PatternRewriter &rewriter, Location loc,
851
+ emitElementwiseComputation (ConversionPatternRewriter &rewriter, Location loc,
847
852
Operation *operation, ValueRange operands,
848
- ArrayRef<OpFoldResult> targetShape) {
853
+ ArrayRef<OpFoldResult> targetShape,
854
+ const TypeConverter &converter) {
849
855
// Generate output tensor
850
- auto resultType = cast<RankedTensorType>(operation->getResultTypes ().front ());
856
+ auto resultType = cast_or_null<RankedTensorType>(
857
+ converter.convertType (operation->getResultTypes ().front ()));
858
+ if (!resultType) {
859
+ return rewriter.notifyMatchFailure (operation, " failed to convert type" );
860
+ }
851
861
Value outputTensor = rewriter.create <tensor::EmptyOp>(
852
862
loc, targetShape, resultType.getElementType ());
853
863
@@ -894,8 +904,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
894
904
}
895
905
896
906
static LogicalResult
897
- elementwiseMatchAndRewriteHelper (Operation *operation,
898
- PatternRewriter &rewriter) {
907
+ elementwiseMatchAndRewriteHelper (Operation *operation, ValueRange operands,
908
+ ConversionPatternRewriter &rewriter,
909
+ const TypeConverter &converter) {
899
910
900
911
// Collect op properties
901
912
assert (operation->getNumResults () == 1 && " elementwise op expects 1 result" );
@@ -908,13 +919,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
908
919
// Lower operation
909
920
IndexPool indexPool;
910
921
auto loc = operation->getLoc ();
911
- auto expandedOperands = expandInputRanks (rewriter, loc, operation);
922
+ auto rank =
923
+ cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
924
+ auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
912
925
auto [targetShape, masterOperands] =
913
926
computeTargetShape (rewriter, loc, indexPool, expandedOperands);
914
927
auto broadcastOperands = broadcastDynamicDimensions (
915
928
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
916
929
return emitElementwiseComputation (rewriter, loc, operation, broadcastOperands,
917
- targetShape);
930
+ targetShape, converter );
918
931
}
919
932
920
933
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1113,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1100
1113
namespace {
1101
1114
1102
1115
template <typename SrcOp>
1103
- class PointwiseConverter : public OpRewritePattern <SrcOp> {
1116
+ class PointwiseConverter : public OpConversionPattern <SrcOp> {
1104
1117
public:
1105
- using OpRewritePattern<SrcOp>::OpRewritePattern;
1118
+ using OpConversionPattern<SrcOp>::OpConversionPattern;
1119
+ using typename OpConversionPattern<SrcOp>::OpAdaptor;
1106
1120
1107
- LogicalResult matchAndRewrite (SrcOp op,
1108
- PatternRewriter &rewriter) const final {
1109
- return elementwiseMatchAndRewriteHelper (op, rewriter);
1121
+ LogicalResult
1122
+ matchAndRewrite (SrcOp op, OpAdaptor operands,
1123
+ ConversionPatternRewriter &rewriter) const final {
1124
+ return elementwiseMatchAndRewriteHelper (
1125
+ op, operands.getOperands (), rewriter, *this ->getTypeConverter ());
1110
1126
}
1111
1127
};
1112
1128
@@ -1279,7 +1295,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1279
1295
loc, nestedBuilder.getI32IntegerAttr (intMax));
1280
1296
1281
1297
value = clampIntHelper (nestedLoc, value, intMinVal, intMaxVal,
1282
- nestedBuilder);
1298
+ nestedBuilder, /* isUnsigned= */ false );
1283
1299
1284
1300
if (outIntType.getWidth () < 32 ) {
1285
1301
value = nestedBuilder.create <arith::TruncIOp>(
@@ -1643,7 +1659,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1643
1659
1644
1660
auto offset = b.create <arith::SelectOp>(pred, one, zeroI32);
1645
1661
val = b.create <arith::AddIOp>(val, offset);
1646
- val = clampIntHelper (loc, val, zeroI32, max, b);
1662
+ val = clampIntHelper (loc, val, zeroI32, max, b, /* isUnsigned= */ false );
1647
1663
return b.create <arith::IndexCastOp>(b.getIndexType (), val);
1648
1664
};
1649
1665
@@ -1664,8 +1680,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1664
1680
Value max, ImplicitLocOpBuilder &b) {
1665
1681
val0 = in;
1666
1682
val1 = b.create <arith::AddIOp>(val0, oneVal);
1667
- val0 = clampIntHelper (loc, val0, zeroI32, max, b);
1668
- val1 = clampIntHelper (loc, val1, zeroI32, max, b);
1683
+ val0 =
1684
+ clampIntHelper (loc, val0, zeroI32, max, b, /* isUnsigned=*/ false );
1685
+ val1 =
1686
+ clampIntHelper (loc, val1, zeroI32, max, b, /* isUnsigned=*/ false );
1669
1687
val0 = b.create <arith::IndexCastOp>(b.getIndexType (), val0);
1670
1688
val1 = b.create <arith::IndexCastOp>(b.getIndexType (), val1);
1671
1689
};
@@ -2552,7 +2570,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2552
2570
} // namespace
2553
2571
2554
2572
void mlir::tosa::populateTosaToLinalgConversionPatterns (
2555
- RewritePatternSet *patterns) {
2573
+ TypeConverter &converter, RewritePatternSet *patterns) {
2556
2574
2557
2575
// We have multiple resize coverters to handle degenerate cases.
2558
2576
patterns->add <GenericResizeConverter>(patterns->getContext (),
@@ -2599,7 +2617,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
2599
2617
PointwiseConverter<tosa::CeilOp>,
2600
2618
PointwiseConverter<tosa::FloorOp>,
2601
2619
PointwiseConverter<tosa::ClampOp>,
2602
- PointwiseConverter<tosa::SigmoidOp>,
2620
+ PointwiseConverter<tosa::SigmoidOp>
2621
+ >(converter, patterns->getContext ());
2622
+
2623
+ patterns->add <
2603
2624
IdentityNConverter<tosa::IdentityOp>,
2604
2625
ReduceConverter<tosa::ReduceAllOp>,
2605
2626
ReduceConverter<tosa::ReduceAnyOp>,
0 commit comments