Skip to content

Commit a0cb359

Browse files
committed
switch to using CompositeInsert from Construct
- commit to demonstrate how we could potentially use CompositeInsert instead of CompositeConstruct
1 parent fc568a2 commit a0cb359

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,14 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
177177
Value carrysVal =
178178
rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
179179

180-
Value constituents[2] = {addsVal, carrysVal};
181-
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
182-
constituents);
180+
// Create empty struct
181+
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
182+
// Fill in adds at id 0
183+
Value intermediate =
184+
rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
185+
// Fill in carrys at id 1
186+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
187+
intermediate, 1);
183188
return success();
184189
}
185190
};
@@ -257,9 +262,14 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
257262
Value highBitsVal =
258263
rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
259264

260-
Value constituents[2] = {lowBitsVal, highBitsVal};
261-
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
262-
constituents);
265+
// Create empty struct
266+
Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
267+
// Fill in lowBits at id 0
268+
Value intermediate =
269+
rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
270+
// Fill in highBits at id 1
271+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
272+
intermediate, 1);
263273
return success();
264274
}
265275
};

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,14 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
358358

359359
// CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
360360
// CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
361-
// CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeConstruct %[[CN3]], %[[C0]]
361+
// CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
362+
// CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
363+
// CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
362364
// CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
363365
// CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
364-
// CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeConstruct %[[CN13]], %[[C1]]
366+
// CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
367+
// CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], %[[UNDEF2]][0 : i32]
368+
// CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32]
365369
%0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
366370
%1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
367371

@@ -376,7 +380,9 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
376380

377381
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
378382
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
379-
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
383+
// CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
384+
// CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
385+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
380386
%0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
381387

382388
// CHECK: return %[[CC_CV1_CV2]]
@@ -472,10 +478,14 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
472478

473479
// CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
474480
// CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
475-
// CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeConstruct %[[CN40]], %[[CN1]]
481+
// CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
482+
// CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
483+
// CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]]
476484
// CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
477485
// CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
478-
// CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeConstruct %[[C40]], %[[C0]]
486+
// CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
487+
// CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
488+
// CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32]
479489
%0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
480490
%1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
481491

@@ -490,7 +500,9 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
490500

491501
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
492502
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
493-
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
503+
// CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
504+
// CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
505+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
494506
%0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
495507

496508
// CHECK: return %[[CC_CV1_CV2]]
@@ -533,12 +545,17 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
533545
%cn5 = spirv.Constant -5 : i32
534546
%cn8 = spirv.Constant -8 : i32
535547

548+
536549
// CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
537550
// CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
538-
// CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeConstruct %[[C40]], %[[CN13]]
539551
// CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
540552
// CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
541-
// CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeConstruct %[[CN40]], %[[C4]]
553+
// CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
554+
// CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
555+
// CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32]
556+
// CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
557+
// CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
558+
// CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32]
542559
%0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
543560
%1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
544561

@@ -553,7 +570,9 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
553570

554571
// CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
555572
// CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
556-
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
573+
// CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
574+
// CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]]
575+
// CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]]
557576
%0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
558577

559578
// CHECK: return %[[CC_CV1_CV2]]

0 commit comments

Comments
 (0)