Skip to content

Commit dcbfc96

Browse files
committed
[mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended
Add the missing constant propagation folders for IAddCarry and [S|U]MulExtended. Because spirv.struct does not yet have a constant value representation, the folding is implemented with canonicalization patterns. Also implement additional folding when rhs is 0 for all three ops and when rhs is 1 for UMulExtended. This improves the readability of code lowered to SPIR-V. Part of the work for #70704
1 parent cc21287 commit dcbfc96

File tree

3 files changed

+344
-0
lines changed

3 files changed

+344
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
379379
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
380380
```
381381
}];
382+
383+
let hasCanonicalizer = 1;
382384
}
383385

384386
// -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
607609
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
608610
```
609611
}];
612+
613+
let hasCanonicalizer = 1;
610614
}
611615

612616
// -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
742746
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
743747
```
744748
}];
749+
750+
let hasCanonicalizer = 1;
745751
}
746752

747753
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
115115
results.add<CombineChainedAccessChain>(context);
116116
}
117117

118+
//===----------------------------------------------------------------------===//
119+
// spirv.IAddCarry
120+
//===----------------------------------------------------------------------===//
121+
122+
// We are required to use CompositeConstructOp to create a constant struct as
123+
// they are not yet implemented as constant, hence we can not do so in a fold.
124+
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
125+
using OpRewritePattern::OpRewritePattern;
126+
127+
LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
128+
PatternRewriter &rewriter) const override {
129+
Location loc = op.getLoc();
130+
auto operands = op.getOperands();
131+
132+
SmallVector<Value> constituents;
133+
Type constituentType = operands[0].getType();
134+
135+
// iaddcarry (x, 0) = <0, x>
136+
if (matchPattern(operands[1], m_Zero())) {
137+
constituents.push_back(operands[1]);
138+
constituents.push_back(operands[0]);
139+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
140+
constituents);
141+
return success();
142+
}
143+
144+
// According to the SPIR-V spec:
145+
//
146+
// Result Type must be from OpTypeStruct. The struct must have two
147+
// members...
148+
//
149+
// Member 0 of the result gets the low-order bits (full component width) of
150+
// the addition.
151+
//
152+
// Member 1 of the result gets the high-order (carry) bit of the result of
153+
// the addition. That is, it gets the value 1 if the addition overflowed
154+
// the component width, and 0 otherwise.
155+
Attribute lhs;
156+
Attribute rhs;
157+
if (!matchPattern(operands[0], m_Constant(&lhs)) ||
158+
!matchPattern(operands[1], m_Constant(&rhs)))
159+
return failure();
160+
161+
auto adds = constFoldBinaryOp<IntegerAttr>(
162+
{lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
163+
if (!adds)
164+
return failure();
165+
166+
auto carrys = constFoldBinaryOp<IntegerAttr>(
167+
ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
168+
APInt zero = APInt::getZero(a.getBitWidth());
169+
return a.ult(b) ? (zero + 1) : zero;
170+
});
171+
172+
if (!carrys)
173+
return failure();
174+
175+
Value addsVal =
176+
rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
177+
constituents.push_back(addsVal);
178+
179+
Value carrysVal =
180+
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
181+
constituents.push_back(carrysVal);
182+
183+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
184+
constituents);
185+
return success();
186+
}
187+
};
188+
189+
// Adds the spirv.IAddCarry canonicalization pattern to `patterns`.
void spirv::IAddCarryOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IAddCarryFold>(context);
}
193+
194+
//===----------------------------------------------------------------------===//
195+
// spirv.[S|U]MulExtended
196+
//===----------------------------------------------------------------------===//
197+
198+
// We are required to use CompositeConstructOp to create a constant struct as
199+
// they are not yet implemented as constant, hence we can not do so in a fold.
200+
template <typename MulOp, bool IsSigned>
201+
struct MulExtendedFold final : OpRewritePattern<MulOp> {
202+
using OpRewritePattern<MulOp>::OpRewritePattern;
203+
204+
LogicalResult matchAndRewrite(MulOp op,
205+
PatternRewriter &rewriter) const override {
206+
Location loc = op.getLoc();
207+
auto operands = op.getOperands();
208+
209+
SmallVector<Value> constituents;
210+
Type constituentType = operands[0].getType();
211+
212+
// [su]mulextended (x, 0) = <0, 0>
213+
if (matchPattern(operands[1], m_Zero())) {
214+
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
215+
constituents.push_back(zero);
216+
constituents.push_back(zero);
217+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
218+
constituents);
219+
return success();
220+
}
221+
222+
// According to the SPIR-V spec:
223+
//
224+
// Result Type must be from OpTypeStruct. The struct must have two
225+
// members...
226+
//
227+
// Member 0 of the result gets the low-order bits of the multiplication.
228+
//
229+
// Member 1 of the result gets the high-order bits of the multiplication.
230+
Attribute lhs;
231+
Attribute rhs;
232+
if (!matchPattern(operands[0], m_Constant(&lhs)) ||
233+
!matchPattern(operands[1], m_Constant(&rhs)))
234+
return failure();
235+
236+
auto lowBits = constFoldBinaryOp<IntegerAttr>(
237+
{lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
238+
239+
if (!lowBits)
240+
return failure();
241+
242+
auto highBits = constFoldBinaryOp<IntegerAttr>(
243+
{lhs, rhs}, [](const APInt &a, const APInt &b) {
244+
unsigned bitWidth = a.getBitWidth();
245+
APInt c;
246+
if (IsSigned) {
247+
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
248+
} else {
249+
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
250+
}
251+
return c.extractBits(bitWidth, bitWidth); // Extract high result
252+
});
253+
254+
if (!highBits)
255+
return failure();
256+
257+
Value lowBitsVal =
258+
rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
259+
constituents.push_back(lowBitsVal);
260+
261+
Value highBitsVal =
262+
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
263+
constituents.push_back(highBitsVal);
264+
265+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
266+
constituents);
267+
return success();
268+
}
269+
};
270+
271+
using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
272+
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
273+
RewritePatternSet &patterns, MLIRContext *context) {
274+
patterns.add<SMulExtendedOpFold>(context);
275+
}
276+
277+
// Canonicalization pattern for spirv.UMulExtended with a right-hand side of
// one: the low-order member is the left-hand operand and the high-order
// member is zero.
struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                PatternRewriter &rewriter) const override {
    auto inputs = op.getOperands();

    // Only the (x, 1) form is handled by this pattern.
    if (!matchPattern(inputs[1], m_One()))
      return failure();

    // umulextended (x, 1) = <x, 0>
    Location loc = op.getLoc();
    Type memberType = inputs[0].getType();
    Value zero = spirv::ConstantOp::getZero(memberType, loc, rewriter);

    SmallVector<Value> members = {inputs[0], zero};
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                             members);
    return success();
  }
};
301+
302+
using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
303+
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
304+
RewritePatternSet &patterns, MLIRContext *context) {
305+
patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
306+
}
307+
118308
//===----------------------------------------------------------------------===//
119309
// spirv.UMod
120310
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
336336

337337
// -----
338338

339+
//===----------------------------------------------------------------------===//
340+
// spirv.IAddCarry
341+
//===----------------------------------------------------------------------===//
342+
343+
// IAddCarry with a constant-zero right-hand side is expected to be rewritten
// into a spirv.CompositeConstruct.
// CHECK-LABEL: @iaddcarry_x_0
func.func @iaddcarry_x_0(%x : i32) -> !spirv.struct<(i32, i32)> {
  %zero = spirv.Constant 0 : i32

  // CHECK: spirv.CompositeConstruct
  %result = spirv.IAddCarry %x, %zero : !spirv.struct<(i32, i32)>
  return %result : !spirv.struct<(i32, i32)>
}
351+
352+
// Scalar constant folding of IAddCarry:
//   5 + (-8)    -> sum -3,  carry 0
//   (-5) + (-8) -> sum -13, carry 1
// CHECK-LABEL: @const_fold_scalar_iaddcarry
func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
  %pos5 = spirv.Constant 5 : i32
  %neg5 = spirv.Constant -5 : i32
  %neg8 = spirv.Constant -8 : i32

  // CHECK-DAG: spirv.Constant 0
  // CHECK-DAG: spirv.Constant -3
  // CHECK-DAG: spirv.CompositeConstruct
  // CHECK-DAG: spirv.Constant 1
  // CHECK-DAG: spirv.Constant -13
  // CHECK-DAG: spirv.CompositeConstruct
  %no_carry = spirv.IAddCarry %pos5, %neg8 : !spirv.struct<(i32, i32)>
  %with_carry = spirv.IAddCarry %neg5, %neg8 : !spirv.struct<(i32, i32)>

  return %no_carry, %with_carry : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
369+
370+
// Element-wise constant folding of IAddCarry on 3-element vectors.
// CHECK-LABEL: @const_fold_vector_iaddcarry
func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %lhs = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
  %rhs = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>

  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
  // CHECK-DAG: spirv.CompositeConstruct
  %result = spirv.IAddCarry %lhs, %rhs : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %result : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
382+
383+
// -----
384+
339385
//===----------------------------------------------------------------------===//
340386
// spirv.IMul
341387
//===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
400446

401447
// -----
402448

449+
//===----------------------------------------------------------------------===//
450+
// spirv.SMulExtended
451+
//===----------------------------------------------------------------------===//
452+
453+
// SMulExtended with a constant-zero right-hand side is expected to be
// rewritten into a spirv.CompositeConstruct.
// CHECK-LABEL: @smulextended_x_0
func.func @smulextended_x_0(%x : i32) -> !spirv.struct<(i32, i32)> {
  %zero = spirv.Constant 0 : i32

  // CHECK: spirv.CompositeConstruct
  %result = spirv.SMulExtended %x, %zero : !spirv.struct<(i32, i32)>
  return %result : !spirv.struct<(i32, i32)>
}
461+
462+
// Scalar constant folding of SMulExtended:
//   5 * (-8)    -> low -40, high -1
//   (-5) * (-8) -> low 40,  high 0
// CHECK-LABEL: @const_fold_scalar_smulextended
func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
  %pos5 = spirv.Constant 5 : i32
  %neg5 = spirv.Constant -5 : i32
  %neg8 = spirv.Constant -8 : i32

  // CHECK-DAG: spirv.Constant -40
  // CHECK-DAG: spirv.Constant -1
  // CHECK-DAG: spirv.CompositeConstruct
  // CHECK-DAG: spirv.Constant 40
  // CHECK-DAG: spirv.Constant 0
  // CHECK-DAG: spirv.CompositeConstruct
  %mixed_sign = spirv.SMulExtended %pos5, %neg8 : !spirv.struct<(i32, i32)>
  %both_negative = spirv.SMulExtended %neg5, %neg8 : !spirv.struct<(i32, i32)>

  return %mixed_sign, %both_negative : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
479+
480+
// Element-wise constant folding of SMulExtended on 3-element vectors. The
// low-order constant is expected first, directly followed by the high-order
// constant and the composite.
// CHECK-LABEL: @const_fold_vector_smulextended
func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %lhs = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
  %rhs = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>

  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
  // CHECK-NEXT: spirv.CompositeConstruct
  %result = spirv.SMulExtended %lhs, %rhs : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %result : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
492+
493+
// -----
494+
495+
//===----------------------------------------------------------------------===//
496+
// spirv.UMulExtended
497+
//===----------------------------------------------------------------------===//
498+
499+
// UMulExtended with a constant-zero right-hand side is expected to be
// rewritten into a spirv.CompositeConstruct.
// CHECK-LABEL: @umulextended_x_0
func.func @umulextended_x_0(%x : i32) -> !spirv.struct<(i32, i32)> {
  %zero = spirv.Constant 0 : i32

  // CHECK: spirv.CompositeConstruct
  %result = spirv.UMulExtended %x, %zero : !spirv.struct<(i32, i32)>
  return %result : !spirv.struct<(i32, i32)>
}
507+
508+
// UMulExtended with a constant-one right-hand side is expected to be
// rewritten into a spirv.CompositeConstruct.
// CHECK-LABEL: @umulextended_x_1
func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
  // Named %c1 because the value is the constant one.
  %c1 = spirv.Constant 1 : i32

  // CHECK: spirv.CompositeConstruct
  %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
  return %0 : !spirv.struct<(i32, i32)>
}
516+
517+
// Scalar constant folding of UMulExtended; the operands are interpreted as
// unsigned 32-bit values:
//   5 * (-8)    -> low -40, high 4
//   (-5) * (-8) -> low 40,  high -13
// CHECK-LABEL: @const_fold_scalar_umulextended
func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
  %pos5 = spirv.Constant 5 : i32
  %neg5 = spirv.Constant -5 : i32
  %neg8 = spirv.Constant -8 : i32

  // CHECK-DAG: spirv.Constant 40
  // CHECK-DAG: spirv.Constant -13
  // CHECK-DAG: spirv.CompositeConstruct
  // CHECK-DAG: spirv.Constant -40
  // CHECK-DAG: spirv.Constant 4
  // CHECK-DAG: spirv.CompositeConstruct
  %first = spirv.UMulExtended %pos5, %neg8 : !spirv.struct<(i32, i32)>
  %second = spirv.UMulExtended %neg5, %neg8 : !spirv.struct<(i32, i32)>

  return %first, %second : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
}
534+
535+
// Element-wise constant folding of UMulExtended on 3-element vectors. The
// low-order constant is expected first, directly followed by the high-order
// constant and the composite.
// CHECK-LABEL: @const_fold_vector_umulextended
func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
  %lhs = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
  %rhs = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>

  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
  // CHECK-NEXT: spirv.CompositeConstruct
  %result = spirv.UMulExtended %lhs, %rhs : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
  return %result : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}
547+
548+
// -----
549+
550+
403551
//===----------------------------------------------------------------------===//
404552
// spirv.ISub
405553
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)