@@ -52,6 +52,23 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
52
52
}
53
53
};
54
54
55
+ Type matchContainerType (Type element, Type container) {
56
+ if (auto shapedTy = container.dyn_cast <ShapedType>())
57
+ return shapedTy.clone (element);
58
+
59
+ return element;
60
+ }
61
+
62
+ Attribute getConstantAttr (Type type, int64_t value, PatternRewriter &rewriter) {
63
+ if (auto shapedTy = type.dyn_cast <ShapedType>()) {
64
+ Type eTy = shapedTy.getElementType ();
65
+ APInt valueInt (eTy.getIntOrFloatBitWidth (), value);
66
+ return DenseIntElementsAttr::get (shapedTy, valueInt);
67
+ }
68
+
69
+ return rewriter.getIntegerAttr (type, value);
70
+ }
71
+
55
72
// This converts the TOSA ApplyScale operator to a set of StandardOps ops,
56
73
// using 64-bit operations to perform the necessary multiply, bias, and shift.
57
74
// Multiple types are used to use minimal bit width operations.
@@ -65,13 +82,19 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
65
82
Value value32 = op.value ();
66
83
Value multiplier32 = op.multiplier ();
67
84
Value shift8 = op.shift ();
85
+
68
86
bool doubleRound = op.double_round ();
69
87
Type inType = op.value ().getType ();
88
+ Type resultTy = op.getType ();
89
+
90
+ Type i8Ty = matchContainerType (rewriter.getIntegerType (8 ), resultTy);
91
+ Type i32Ty = matchContainerType (rewriter.getI32Type (), resultTy);
92
+ Type i64Ty = matchContainerType (rewriter.getI64Type (), resultTy);
70
93
71
94
Value one8 = rewriter.create <arith::ConstantOp>(
72
- loc, rewriter. getIntegerAttr (rewriter. getIntegerType ( 8 ) , 1 ));
95
+ loc, getConstantAttr (i8Ty , 1 , rewriter ));
73
96
Value one64 = rewriter.create <arith::ConstantOp>(
74
- loc, rewriter. getIntegerAttr (rewriter. getI64Type () , 1 ));
97
+ loc, getConstantAttr (i64Ty , 1 , rewriter ));
75
98
76
99
Value shiftSubOne8 = rewriter.create <arith::SubIOp>(loc, shift8, one8);
77
100
@@ -85,23 +108,20 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
85
108
// Note that minimal bitwidth operators are used throughout the block.
86
109
87
110
Value round64 = rewriter.create <arith::ShLIOp>(
88
- loc, one64,
89
- rewriter.create <arith::ExtSIOp>(loc, rewriter.getI64Type (),
90
- shiftSubOne8));
111
+ loc, one64, rewriter.create <arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
91
112
92
113
// Double rounding is performing a round operation before the shift
93
114
if (doubleRound) {
94
115
Value one32 = rewriter.create <arith::ConstantOp>(
95
- loc, rewriter.getIntegerAttr (rewriter.getI32Type (), 1 ));
96
- Value shift32 =
97
- rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), shift8);
116
+ loc, getConstantAttr (i32Ty, 1 , rewriter));
117
+ Value shift32 = rewriter.create <arith::ExtSIOp>(loc, i32Ty, shift8);
98
118
Value thirty32 = rewriter.create <arith::ConstantOp>(
99
- loc, rewriter. getIntegerAttr (rewriter. getI32Type () , 30 ));
119
+ loc, getConstantAttr (i32Ty , 30 , rewriter ));
100
120
101
121
Value shiftThirty32 =
102
122
rewriter.create <arith::ShLIOp>(loc, one32, thirty32);
103
- Value shiftThirty64 = rewriter. create <arith::ExtSIOp>(
104
- loc, rewriter.getI64Type () , shiftThirty32);
123
+ Value shiftThirty64 =
124
+ rewriter.create <arith::ExtSIOp>(loc, i64Ty , shiftThirty32);
105
125
106
126
// Round value needs to be either added or subtracted depending on the sign
107
127
// of the input value.
@@ -120,7 +140,7 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
120
140
121
141
// We only perform double rounding if the shift value is greater than 32.
122
142
Value thirtyTwo32 = rewriter.create <arith::ConstantOp>(
123
- loc, rewriter. getIntegerAttr (rewriter. getI32Type () , 32 ));
143
+ loc, getConstantAttr (i32Ty , 32 , rewriter ));
124
144
Value shiftGreaterThanThirtyTwo = rewriter.create <arith::CmpIOp>(
125
145
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
126
146
round64 = rewriter.create <mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
@@ -133,20 +153,17 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
133
153
//
134
154
// Note that multiply and shift need to be performed in i64 to preserve bits.
135
155
136
- Value value64 =
137
- rewriter.create <arith::ExtSIOp>(loc, rewriter.getI64Type (), value32);
138
- Value multiplier64 = rewriter.create <arith::ExtSIOp>(
139
- loc, rewriter.getI64Type (), multiplier32);
140
- Value shift64 =
141
- rewriter.create <arith::ExtSIOp>(loc, rewriter.getI64Type (), shift8);
156
+ Value value64 = rewriter.create <arith::ExtSIOp>(loc, i64Ty, value32);
157
+ Value multiplier64 =
158
+ rewriter.create <arith::ExtSIOp>(loc, i64Ty, multiplier32);
159
+ Value shift64 = rewriter.create <arith::ExtSIOp>(loc, i64Ty, shift8);
142
160
143
161
// Multiply as a pair of i64 values to guarantee the end value fits.
144
162
Value result64 = rewriter.create <arith::MulIOp>(loc, value64, multiplier64);
145
163
result64 = rewriter.create <arith::AddIOp>(loc, result64, round64);
146
164
result64 = rewriter.create <arith::ShRSIOp>(loc, result64, shift64);
147
165
148
- Value result32 =
149
- rewriter.create <arith::TruncIOp>(loc, rewriter.getI32Type (), result64);
166
+ Value result32 = rewriter.create <arith::TruncIOp>(loc, resultTy, result64);
150
167
151
168
rewriter.replaceOp (op, result32);
152
169
return success ();
0 commit comments