Skip to content

Commit fc568a2

Browse files
committed
review comments
- improve readability with lhs/rhs instead of operands[0]/[1] - use stack array instead of llvm::SmallVector - increase strictness of tests to ensure proper CompositeConstruct and return order
1 parent dcbfc96 commit fc568a2

File tree

2 files changed

+83
-74
lines changed

2 files changed

+83
-74
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,13 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
127127
LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
128128
PatternRewriter &rewriter) const override {
129129
Location loc = op.getLoc();
130-
auto operands = op.getOperands();
131-
132-
SmallVector<Value> constituents;
133-
Type constituentType = operands[0].getType();
130+
Value lhs = op.getOperand1();
131+
Value rhs = op.getOperand2();
132+
Type constituentType = lhs.getType();
134133

135134
// iaddcarry (x, 0) = <0, x>
136-
if (matchPattern(operands[1], m_Zero())) {
137-
constituents.push_back(operands[1]);
138-
constituents.push_back(operands[0]);
135+
if (matchPattern(rhs, m_Zero())) {
136+
Value constituents[2] = {rhs, lhs};
139137
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
140138
constituents);
141139
return success();
@@ -152,19 +150,20 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
152150
// Member 1 of the result gets the high-order (carry) bit of the result of
153151
// the addition. That is, it gets the value 1 if the addition overflowed
154152
// the component width, and 0 otherwise.
155-
Attribute lhs;
156-
Attribute rhs;
157-
if (!matchPattern(operands[0], m_Constant(&lhs)) ||
158-
!matchPattern(operands[1], m_Constant(&rhs)))
153+
Attribute lhsAttr;
154+
Attribute rhsAttr;
155+
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
156+
!matchPattern(rhs, m_Constant(&rhsAttr)))
159157
return failure();
160158

161159
auto adds = constFoldBinaryOp<IntegerAttr>(
162-
{lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
160+
{lhsAttr, rhsAttr},
161+
[](const APInt &a, const APInt &b) { return a + b; });
163162
if (!adds)
164163
return failure();
165164

166165
auto carrys = constFoldBinaryOp<IntegerAttr>(
167-
ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
166+
ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
168167
APInt zero = APInt::getZero(a.getBitWidth());
169168
return a.ult(b) ? (zero + 1) : zero;
170169
});
@@ -174,12 +173,11 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
174173

175174
Value addsVal =
176175
rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
177-
constituents.push_back(addsVal);
178176

179177
Value carrysVal =
180178
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
181-
constituents.push_back(carrysVal);
182179

180+
Value constituents[2] = {addsVal, carrysVal};
183181
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
184182
constituents);
185183
return success();
@@ -204,16 +202,14 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
204202
LogicalResult matchAndRewrite(MulOp op,
205203
PatternRewriter &rewriter) const override {
206204
Location loc = op.getLoc();
207-
auto operands = op.getOperands();
208-
209-
SmallVector<Value> constituents;
210-
Type constituentType = operands[0].getType();
205+
Value lhs = op.getOperand1();
206+
Value rhs = op.getOperand2();
207+
Type constituentType = lhs.getType();
211208

212209
// [su]mulextended (x, 0) = <0, 0>
213-
if (matchPattern(operands[1], m_Zero())) {
210+
if (matchPattern(rhs, m_Zero())) {
214211
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
215-
constituents.push_back(zero);
216-
constituents.push_back(zero);
212+
Value constituents[2] = {zero, zero};
217213
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
218214
constituents);
219215
return success();
@@ -227,20 +223,21 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
227223
// Member 0 of the result gets the low-order bits of the multiplication.
228224
//
229225
// Member 1 of the result gets the high-order bits of the multiplication.
230-
Attribute lhs;
231-
Attribute rhs;
232-
if (!matchPattern(operands[0], m_Constant(&lhs)) ||
233-
!matchPattern(operands[1], m_Constant(&rhs)))
226+
Attribute lhsAttr;
227+
Attribute rhsAttr;
228+
if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
229+
!matchPattern(rhs, m_Constant(&rhsAttr)))
234230
return failure();
235231

236232
auto lowBits = constFoldBinaryOp<IntegerAttr>(
237-
{lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
233+
{lhsAttr, rhsAttr},
234+
[](const APInt &a, const APInt &b) { return a * b; });
238235

239236
if (!lowBits)
240237
return failure();
241238

242239
auto highBits = constFoldBinaryOp<IntegerAttr>(
243-
{lhs, rhs}, [](const APInt &a, const APInt &b) {
240+
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
244241
unsigned bitWidth = a.getBitWidth();
245242
APInt c;
246243
if (IsSigned) {
@@ -256,12 +253,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
256253

257254
Value lowBitsVal =
258255
rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
259-
constituents.push_back(lowBitsVal);
260256

261257
Value highBitsVal =
262258
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
263-
constituents.push_back(highBitsVal);
264259

260+
Value constituents[2] = {lowBitsVal, highBitsVal};
265261
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
266262
constituents);
267263
return success();
@@ -280,16 +276,14 @@ struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
280276
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
281277
PatternRewriter &rewriter) const override {
282278
Location loc = op.getLoc();
283-
auto operands = op.getOperands();
284-
285-
SmallVector<Value> constituents;
286-
Type constituentType = operands[0].getType();
279+
Value lhs = op.getOperand1();
280+
Value rhs = op.getOperand2();
281+
Type constituentType = lhs.getType();
287282

288283
// umulextended (x, 1) = <x, 0>
289-
if (matchPattern(operands[1], m_One())) {
284+
if (matchPattern(rhs, m_One())) {
290285
Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
291-
constituents.push_back(operands[0]);
292-
constituents.push_back(zero);
286+
Value constituents[2] = {lhs, zero};
293287
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
294288
constituents);
295289
return success();

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,11 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
342342

343343
// CHECK-LABEL: @iaddcarry_x_0
344344
func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
345+
// CHECK: %[[RET:.*]] = spirv.CompositeConstruct
345346
%c0 = spirv.Constant 0 : i32
346-
347-
// CHECK: spirv.CompositeConstruct
348347
%0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
348+
349+
// CHECK: return %[[RET]]
349350
return %0 : !spirv.struct<(i32, i32)>
350351
}
351352

@@ -355,15 +356,16 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
355356
%cn5 = spirv.Constant -5 : i32
356357
%cn8 = spirv.Constant -8 : i32
357358

358-
// CHECK-DAG: spirv.Constant 0
359-
// CHECK-DAG: spirv.Constant -3
360-
// CHECK-DAG: spirv.CompositeConstruct
361-
// CHECK-DAG: spirv.Constant 1
362-
// CHECK-DAG: spirv.Constant -13
363-
// CHECK-DAG: spirv.CompositeConstruct
359+
// CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
360+
// CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
361+
// CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeConstruct %[[CN3]], %[[C0]]
362+
// CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
363+
// CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
364+
// CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeConstruct %[[CN13]], %[[C1]]
364365
%0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
365366
%1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
366367

368+
// CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]]
367369
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
368370
}
369371

@@ -372,12 +374,13 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
372374
%v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
373375
%v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
374376

375-
// CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
376-
// CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
377-
// CHECK-DAG: spirv.CompositeConstruct
377+
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
378+
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
379+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
378380
%0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
379-
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
380381

382+
// CHECK: return %[[CC_CV1_CV2]]
383+
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
381384
}
382385

383386
// -----
@@ -452,10 +455,12 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
452455

453456
// CHECK-LABEL: @smulextended_x_0
454457
func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
458+
// CHECK: %[[C0:.*]] = spirv.Constant 0
459+
// CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
455460
%c0 = spirv.Constant 0 : i32
456-
457-
// CHECK: spirv.CompositeConstruct
458461
%0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
462+
463+
// CHECK: return %[[RET]]
459464
return %0 : !spirv.struct<(i32, i32)>
460465
}
461466

@@ -465,15 +470,16 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
465470
%cn5 = spirv.Constant -5 : i32
466471
%cn8 = spirv.Constant -8 : i32
467472

468-
// CHECK-DAG: spirv.Constant -40
469-
// CHECK-DAG: spirv.Constant -1
470-
// CHECK-DAG: spirv.CompositeConstruct
471-
// CHECK-DAG: spirv.Constant 40
472-
// CHECK-DAG: spirv.Constant 0
473-
// CHECK-DAG: spirv.CompositeConstruct
473+
// CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
474+
// CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
475+
// CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeConstruct %[[CN40]], %[[CN1]]
476+
// CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
477+
// CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
478+
// CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeConstruct %[[C40]], %[[C0]]
474479
%0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
475480
%1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
476481

482+
// CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]]
477483
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
478484
}
479485

@@ -482,10 +488,12 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
482488
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
483489
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
484490

485-
// CHECK: spirv.Constant dense<[2147483643, 40, -1]>
486-
// CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
487-
// CHECK-NEXT: spirv.CompositeConstruct
491+
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
492+
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
493+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
488494
%0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
495+
496+
// CHECK: return %[[CC_CV1_CV2]]
489497
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
490498

491499
}
@@ -498,19 +506,24 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
498506

499507
// CHECK-LABEL: @umulextended_x_0
500508
func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
509+
// CHECK: %[[C0:.*]] = spirv.Constant 0
510+
// CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
501511
%c0 = spirv.Constant 0 : i32
502-
503-
// CHECK: spirv.CompositeConstruct
504512
%0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
513+
514+
// CHECK: return %[[RET]]
505515
return %0 : !spirv.struct<(i32, i32)>
506516
}
507517

508518
// CHECK-LABEL: @umulextended_x_1
519+
// CHECK-SAME: (%[[ARG:.*]]: i32)
509520
func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
521+
// CHECK: %[[C0:.*]] = spirv.Constant 0
522+
// CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[ARG]], %[[C0]]
510523
%c0 = spirv.Constant 1 : i32
511-
512-
// CHECK: spirv.CompositeConstruct
513524
%0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
525+
526+
// CHECK: return %[[RET]]
514527
return %0 : !spirv.struct<(i32, i32)>
515528
}
516529

@@ -520,15 +533,16 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
520533
%cn5 = spirv.Constant -5 : i32
521534
%cn8 = spirv.Constant -8 : i32
522535

523-
// CHECK-DAG: spirv.Constant 40
524-
// CHECK-DAG: spirv.Constant -13
525-
// CHECK-DAG: spirv.CompositeConstruct
526-
// CHECK-DAG: spirv.Constant -40
527-
// CHECK-DAG: spirv.Constant 4
528-
// CHECK-DAG: spirv.CompositeConstruct
536+
// CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
537+
// CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
538+
// CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeConstruct %[[C40]], %[[CN13]]
539+
// CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
540+
// CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
541+
// CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeConstruct %[[CN40]], %[[C4]]
529542
%0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
530543
%1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
531544

545+
// CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]]
532546
return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
533547
}
534548

@@ -537,12 +551,13 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
537551
%v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
538552
%v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
539553

540-
// CHECK: spirv.Constant dense<[2147483643, 40, -1]>
541-
// CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
542-
// CHECK-NEXT: spirv.CompositeConstruct
554+
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
555+
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
556+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
543557
%0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
544-
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
545558

559+
// CHECK: return %[[CC_CV1_CV2]]
560+
return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
546561
}
547562

548563
// -----

0 commit comments

Comments
 (0)