@@ -122,6 +122,200 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
122
122
results.add <CombineChainedAccessChain>(context);
123
123
}
124
124
125
+ // ===----------------------------------------------------------------------===//
126
+ // spirv.IAddCarry
127
+ // ===----------------------------------------------------------------------===//
128
+
129
+ // We are required to use CompositeConstructOp to create a constant struct as
130
+ // they are not yet implemented as constant, hence we can not do so in a fold.
131
+ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
132
+ using OpRewritePattern::OpRewritePattern;
133
+
134
+ LogicalResult matchAndRewrite (spirv::IAddCarryOp op,
135
+ PatternRewriter &rewriter) const override {
136
+ Location loc = op.getLoc ();
137
+ Value lhs = op.getOperand1 ();
138
+ Value rhs = op.getOperand2 ();
139
+ Type constituentType = lhs.getType ();
140
+
141
+ // iaddcarry (x, 0) = <0, x>
142
+ if (matchPattern (rhs, m_Zero ())) {
143
+ Value constituents[2 ] = {rhs, lhs};
144
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(op, op.getType (),
145
+ constituents);
146
+ return success ();
147
+ }
148
+
149
+ // According to the SPIR-V spec:
150
+ //
151
+ // Result Type must be from OpTypeStruct. The struct must have two
152
+ // members...
153
+ //
154
+ // Member 0 of the result gets the low-order bits (full component width) of
155
+ // the addition.
156
+ //
157
+ // Member 1 of the result gets the high-order (carry) bit of the result of
158
+ // the addition. That is, it gets the value 1 if the addition overflowed
159
+ // the component width, and 0 otherwise.
160
+ Attribute lhsAttr;
161
+ Attribute rhsAttr;
162
+ if (!matchPattern (lhs, m_Constant (&lhsAttr)) ||
163
+ !matchPattern (rhs, m_Constant (&rhsAttr)))
164
+ return failure ();
165
+
166
+ auto adds = constFoldBinaryOp<IntegerAttr>(
167
+ {lhsAttr, rhsAttr},
168
+ [](const APInt &a, const APInt &b) { return a + b; });
169
+ if (!adds)
170
+ return failure ();
171
+
172
+ auto carrys = constFoldBinaryOp<IntegerAttr>(
173
+ ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174
+ APInt zero = APInt::getZero (a.getBitWidth ());
175
+ return a.ult (b) ? (zero + 1 ) : zero;
176
+ });
177
+
178
+ if (!carrys)
179
+ return failure ();
180
+
181
+ Value addsVal =
182
+ rewriter.create <spirv::ConstantOp>(loc, constituentType, adds);
183
+
184
+ Value carrysVal =
185
+ rewriter.create <spirv::ConstantOp>(loc, constituentType, carrys);
186
+
187
+ // Create empty struct
188
+ Value undef = rewriter.create <spirv::UndefOp>(loc, op.getType ());
189
+ // Fill in adds at id 0
190
+ Value intermediate =
191
+ rewriter.create <spirv::CompositeInsertOp>(loc, addsVal, undef, 0 );
192
+ // Fill in carrys at id 1
193
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(op, carrysVal,
194
+ intermediate, 1 );
195
+ return success ();
196
+ }
197
+ };
198
+
199
+ void spirv::IAddCarryOp::getCanonicalizationPatterns (
200
+ RewritePatternSet &patterns, MLIRContext *context) {
201
+ patterns.add <IAddCarryFold>(context);
202
+ }
203
+
204
+ // ===----------------------------------------------------------------------===//
205
+ // spirv.[S|U]MulExtended
206
+ // ===----------------------------------------------------------------------===//
207
+
208
+ // We are required to use CompositeConstructOp to create a constant struct as
209
+ // they are not yet implemented as constant, hence we can not do so in a fold.
210
+ template <typename MulOp, bool IsSigned>
211
+ struct MulExtendedFold final : OpRewritePattern<MulOp> {
212
+ using OpRewritePattern<MulOp>::OpRewritePattern;
213
+
214
+ LogicalResult matchAndRewrite (MulOp op,
215
+ PatternRewriter &rewriter) const override {
216
+ Location loc = op.getLoc ();
217
+ Value lhs = op.getOperand1 ();
218
+ Value rhs = op.getOperand2 ();
219
+ Type constituentType = lhs.getType ();
220
+
221
+ // [su]mulextended (x, 0) = <0, 0>
222
+ if (matchPattern (rhs, m_Zero ())) {
223
+ Value zero = spirv::ConstantOp::getZero (constituentType, loc, rewriter);
224
+ Value constituents[2 ] = {zero, zero};
225
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(op, op.getType (),
226
+ constituents);
227
+ return success ();
228
+ }
229
+
230
+ // According to the SPIR-V spec:
231
+ //
232
+ // Result Type must be from OpTypeStruct. The struct must have two
233
+ // members...
234
+ //
235
+ // Member 0 of the result gets the low-order bits of the multiplication.
236
+ //
237
+ // Member 1 of the result gets the high-order bits of the multiplication.
238
+ Attribute lhsAttr;
239
+ Attribute rhsAttr;
240
+ if (!matchPattern (lhs, m_Constant (&lhsAttr)) ||
241
+ !matchPattern (rhs, m_Constant (&rhsAttr)))
242
+ return failure ();
243
+
244
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
245
+ {lhsAttr, rhsAttr},
246
+ [](const APInt &a, const APInt &b) { return a * b; });
247
+
248
+ if (!lowBits)
249
+ return failure ();
250
+
251
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
252
+ {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253
+ unsigned bitWidth = a.getBitWidth ();
254
+ APInt c;
255
+ if (IsSigned) {
256
+ c = a.sext (bitWidth * 2 ) * b.sext (bitWidth * 2 );
257
+ } else {
258
+ c = a.zext (bitWidth * 2 ) * b.zext (bitWidth * 2 );
259
+ }
260
+ return c.extractBits (bitWidth, bitWidth); // Extract high result
261
+ });
262
+
263
+ if (!highBits)
264
+ return failure ();
265
+
266
+ Value lowBitsVal =
267
+ rewriter.create <spirv::ConstantOp>(loc, constituentType, lowBits);
268
+
269
+ Value highBitsVal =
270
+ rewriter.create <spirv::ConstantOp>(loc, constituentType, highBits);
271
+
272
+ // Create empty struct
273
+ Value undef = rewriter.create <spirv::UndefOp>(loc, op.getType ());
274
+ // Fill in lowBits at id 0
275
+ Value intermediate =
276
+ rewriter.create <spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0 );
277
+ // Fill in highBits at id 1
278
+ rewriter.replaceOpWithNewOp <spirv::CompositeInsertOp>(op, highBitsVal,
279
+ intermediate, 1 );
280
+ return success ();
281
+ }
282
+ };
283
+
284
+ using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true >;
285
+ void spirv::SMulExtendedOp::getCanonicalizationPatterns (
286
+ RewritePatternSet &patterns, MLIRContext *context) {
287
+ patterns.add <SMulExtendedOpFold>(context);
288
+ }
289
+
290
+ struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
291
+ using OpRewritePattern::OpRewritePattern;
292
+
293
+ LogicalResult matchAndRewrite (spirv::UMulExtendedOp op,
294
+ PatternRewriter &rewriter) const override {
295
+ Location loc = op.getLoc ();
296
+ Value lhs = op.getOperand1 ();
297
+ Value rhs = op.getOperand2 ();
298
+ Type constituentType = lhs.getType ();
299
+
300
+ // umulextended (x, 1) = <x, 0>
301
+ if (matchPattern (rhs, m_One ())) {
302
+ Value zero = spirv::ConstantOp::getZero (constituentType, loc, rewriter);
303
+ Value constituents[2 ] = {lhs, zero};
304
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(op, op.getType (),
305
+ constituents);
306
+ return success ();
307
+ }
308
+
309
+ return failure ();
310
+ }
311
+ };
312
+
313
+ using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false >;
314
+ void spirv::UMulExtendedOp::getCanonicalizationPatterns (
315
+ RewritePatternSet &patterns, MLIRContext *context) {
316
+ patterns.add <UMulExtendedOpFold, UMulExtendedOpXOne>(context);
317
+ }
318
+
125
319
// ===----------------------------------------------------------------------===//
126
320
// spirv.UMod
127
321
// ===----------------------------------------------------------------------===//
0 commit comments