Skip to content

Commit 14028ec

Browse files
authored
[mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended (#73340)
Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns. Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt. This helps for readability of lowered code into SPIR-V. Part of work for #70704
1 parent ce00133 commit 14028ec

File tree

3 files changed

+382
-0
lines changed

3 files changed

+382
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
316316
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
317317
```
318318
}];
319+
320+
let hasCanonicalizer = 1;
319321
}
320322

321323
// -----
@@ -551,6 +553,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
551553
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
552554
```
553555
}];
556+
557+
let hasCanonicalizer = 1;
554558
}
555559

556560
// -----
@@ -675,6 +679,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
675679
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
676680
```
677681
}];
682+
683+
let hasCanonicalizer = 1;
678684
}
679685

680686
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,200 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
122122
results.add<CombineChainedAccessChain>(context);
123123
}
124124

125+
//===----------------------------------------------------------------------===//
126+
// spirv.IAddCarry
127+
//===----------------------------------------------------------------------===//
128+
129+
// We are required to use CompositeConstructOp to create a constant struct as
130+
// they are not yet implemented as constant, hence we can not do so in a fold.
131+
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
132+
using OpRewritePattern::OpRewritePattern;
133+
134+
LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
135+
PatternRewriter &rewriter) const override {
136+
Location loc = op.getLoc();
137+
Value lhs = op.getOperand1();
138+
Value rhs = op.getOperand2();
139+
Type constituentType = lhs.getType();
140+
141+
// iaddcarry (x, 0) = <0, x>
142+
if (matchPattern(rhs, m_Zero())) {
143+
Value constituents[2] = {rhs, lhs};
144+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145+
constituents);
146+
return success();
147+
}
148+
149+
// According to the SPIR-V spec:
150+
//
151+
// Result Type must be from OpTypeStruct. The struct must have two
152+
// members...
153+
//
154+
// Member 0 of the result gets the low-order bits (full component width) of
155+
// the addition.
156+
//
157+
// Member 1 of the result gets the high-order (carry) bit of the result of
158+
// the addition. That is, it gets the value 1 if the addition overflowed
159+
// the component width, and 0 otherwise.
160+
Attribute lhsAttr;
161+
Attribute rhsAttr;
162+
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
163+
!matchPattern(rhs, m_Constant(&rhsAttr)))
164+
return failure();
165+
166+
auto adds = constFoldBinaryOp<IntegerAttr>(
167+
{lhsAttr, rhsAttr},
168+
[](const APInt &a, const APInt &b) { return a + b; });
169+
if (!adds)
170+
return failure();
171+
172+
auto carrys = constFoldBinaryOp<IntegerAttr>(
173+
ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174+
APInt zero = APInt::getZero(a.getBitWidth());
175+
return a.ult(b) ? (zero + 1) : zero;
176+
});
177+
178+
if (!carrys)
179+
return failure();
180+
181+
Value addsVal =
182+
rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
183+
184+
Value carrysVal =
185+
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
186+
187+
// Create empty struct
188+
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
189+
// Fill in adds at id 0
190+
Value intermediate =
191+
rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
192+
// Fill in carrys at id 1
193+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
194+
intermediate, 1);
195+
return success();
196+
}
197+
};
198+
199+
void spirv::IAddCarryOp::getCanonicalizationPatterns(
200+
RewritePatternSet &patterns, MLIRContext *context) {
201+
patterns.add<IAddCarryFold>(context);
202+
}
203+
204+
//===----------------------------------------------------------------------===//
205+
// spirv.[S|U]MulExtended
206+
//===----------------------------------------------------------------------===//
207+
208+
// We are required to use CompositeConstructOp to create a constant struct as
209+
// they are not yet implemented as constant, hence we can not do so in a fold.
210+
template <typename MulOp, bool IsSigned>
211+
struct MulExtendedFold final : OpRewritePattern<MulOp> {
212+
using OpRewritePattern<MulOp>::OpRewritePattern;
213+
214+
LogicalResult matchAndRewrite(MulOp op,
215+
PatternRewriter &rewriter) const override {
216+
Location loc = op.getLoc();
217+
Value lhs = op.getOperand1();
218+
Value rhs = op.getOperand2();
219+
Type constituentType = lhs.getType();
220+
221+
// [su]mulextended (x, 0) = <0, 0>
222+
if (matchPattern(rhs, m_Zero())) {
223+
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
224+
Value constituents[2] = {zero, zero};
225+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
226+
constituents);
227+
return success();
228+
}
229+
230+
// According to the SPIR-V spec:
231+
//
232+
// Result Type must be from OpTypeStruct. The struct must have two
233+
// members...
234+
//
235+
// Member 0 of the result gets the low-order bits of the multiplication.
236+
//
237+
// Member 1 of the result gets the high-order bits of the multiplication.
238+
Attribute lhsAttr;
239+
Attribute rhsAttr;
240+
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
241+
!matchPattern(rhs, m_Constant(&rhsAttr)))
242+
return failure();
243+
244+
auto lowBits = constFoldBinaryOp<IntegerAttr>(
245+
{lhsAttr, rhsAttr},
246+
[](const APInt &a, const APInt &b) { return a * b; });
247+
248+
if (!lowBits)
249+
return failure();
250+
251+
auto highBits = constFoldBinaryOp<IntegerAttr>(
252+
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253+
unsigned bitWidth = a.getBitWidth();
254+
APInt c;
255+
if (IsSigned) {
256+
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
257+
} else {
258+
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
259+
}
260+
return c.extractBits(bitWidth, bitWidth); // Extract high result
261+
});
262+
263+
if (!highBits)
264+
return failure();
265+
266+
Value lowBitsVal =
267+
rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
268+
269+
Value highBitsVal =
270+
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
271+
272+
// Create empty struct
273+
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
274+
// Fill in lowBits at id 0
275+
Value intermediate =
276+
rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
277+
// Fill in highBits at id 1
278+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
279+
intermediate, 1);
280+
return success();
281+
}
282+
};
283+
284+
using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
285+
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
286+
RewritePatternSet &patterns, MLIRContext *context) {
287+
patterns.add<SMulExtendedOpFold>(context);
288+
}
289+
290+
struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
291+
using OpRewritePattern::OpRewritePattern;
292+
293+
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
294+
PatternRewriter &rewriter) const override {
295+
Location loc = op.getLoc();
296+
Value lhs = op.getOperand1();
297+
Value rhs = op.getOperand2();
298+
Type constituentType = lhs.getType();
299+
300+
// umulextended (x, 1) = <x, 0>
301+
if (matchPattern(rhs, m_One())) {
302+
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
303+
Value constituents[2] = {lhs, zero};
304+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
305+
constituents);
306+
return success();
307+
}
308+
309+
return failure();
310+
}
311+
};
312+
313+
using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
314+
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
315+
RewritePatternSet &patterns, MLIRContext *context) {
316+
patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
317+
}
318+
125319
//===----------------------------------------------------------------------===//
126320
// spirv.UMod
127321
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)