@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
50
50
assert (targetBits % sourceBits == 0 );
51
51
Type type = srcIdx.getType ();
52
52
IntegerAttr idxAttr = builder.getIntegerAttr (type, targetBits / sourceBits);
53
- auto idx = builder.create <spirv::ConstantOp>(loc, type, idxAttr);
53
+ auto idx = builder.createOrFold <spirv::ConstantOp>(loc, type, idxAttr);
54
54
IntegerAttr srcBitsAttr = builder.getIntegerAttr (type, sourceBits);
55
- auto srcBitsValue = builder.create <spirv::ConstantOp>(loc, type, srcBitsAttr);
56
- auto m = builder.create <spirv::UModOp>(loc, srcIdx, idx);
57
- return builder.create <spirv::IMulOp>(loc, type, m, srcBitsValue);
55
+ auto srcBitsValue =
56
+ builder.createOrFold <spirv::ConstantOp>(loc, type, srcBitsAttr);
57
+ auto m = builder.createOrFold <spirv::UModOp>(loc, srcIdx, idx);
58
+ return builder.createOrFold <spirv::IMulOp>(loc, type, m, srcBitsValue);
58
59
}
59
60
60
61
// / Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
74
75
Value lastDim = op->getOperand (op.getNumOperands () - 1 );
75
76
Type type = lastDim.getType ();
76
77
IntegerAttr attr = builder.getIntegerAttr (type, targetBits / sourceBits);
77
- auto idx = builder.create <spirv::ConstantOp>(loc, type, attr);
78
+ auto idx = builder.createOrFold <spirv::ConstantOp>(loc, type, attr);
78
79
auto indices = llvm::to_vector<4 >(op.getIndices ());
79
80
// There are two elements if this is a 1-D tensor.
80
81
assert (indices.size () == 2 );
81
- indices.back () = builder.create <spirv::SDivOp>(loc, lastDim, idx);
82
+ indices.back () = builder.createOrFold <spirv::SDivOp>(loc, lastDim, idx);
82
83
Type t = typeConverter.convertType (op.getComponentPtr ().getType ());
83
84
return builder.create <spirv::AccessChainOp>(loc, t, op.getBasePtr (), indices);
84
85
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
91
92
return srcBool;
92
93
Value zero = spirv::ConstantOp::getZero (dstType, loc, builder);
93
94
Value one = spirv::ConstantOp::getOne (dstType, loc, builder);
94
- return builder.create <spirv::SelectOp>(loc, dstType, srcBool, one, zero);
95
+ return builder.createOrFold <spirv::SelectOp>(loc, dstType, srcBool, one,
96
+ zero);
95
97
}
96
98
97
99
// / Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
111
113
loc, builder.getIntegerType (targetBits), value);
112
114
}
113
115
114
- value = builder.create <spirv::BitwiseAndOp>(loc, value, mask);
116
+ value = builder.createOrFold <spirv::BitwiseAndOp>(loc, value, mask);
115
117
}
116
- return builder.create <spirv::ShiftLeftLogicalOp>(loc, value.getType (), value ,
117
- offset);
118
+ return builder.createOrFold <spirv::ShiftLeftLogicalOp>(loc, value.getType (),
119
+ value, offset);
118
120
}
119
121
120
122
// / Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
165
167
return srcInt;
166
168
167
169
auto one = spirv::ConstantOp::getOne (srcInt.getType (), loc, builder);
168
- return builder.create <spirv::IEqualOp>(loc, srcInt, one);
170
+ return builder.createOrFold <spirv::IEqualOp>(loc, srcInt, one);
169
171
}
170
172
171
173
// ===----------------------------------------------------------------------===//
@@ -597,25 +599,26 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
597
599
// ____XXXX________ -> ____________XXXX
598
600
Value lastDim = accessChainOp->getOperand (accessChainOp.getNumOperands () - 1 );
599
601
Value offset = getOffsetForBitwidth (loc, lastDim, srcBits, dstBits, rewriter);
600
- Value result = rewriter.create <spirv::ShiftRightArithmeticOp>(
602
+ Value result = rewriter.createOrFold <spirv::ShiftRightArithmeticOp>(
601
603
loc, spvLoadOp.getType (), spvLoadOp, offset);
602
604
603
605
// Apply the mask to extract corresponding bits.
604
- Value mask = rewriter.create <spirv::ConstantOp>(
606
+ Value mask = rewriter.createOrFold <spirv::ConstantOp>(
605
607
loc, dstType, rewriter.getIntegerAttr (dstType, (1 << srcBits) - 1 ));
606
- result = rewriter.create <spirv::BitwiseAndOp>(loc, dstType, result, mask);
608
+ result =
609
+ rewriter.createOrFold <spirv::BitwiseAndOp>(loc, dstType, result, mask);
607
610
608
611
// Apply sign extension on the loading value unconditionally. The signedness
609
612
// semantic is carried in the operator itself, we relies other pattern to
610
613
// handle the casting.
611
614
IntegerAttr shiftValueAttr =
612
615
rewriter.getIntegerAttr (dstType, dstBits - srcBits);
613
616
Value shiftValue =
614
- rewriter.create <spirv::ConstantOp>(loc, dstType, shiftValueAttr);
615
- result = rewriter.create <spirv::ShiftLeftLogicalOp>(loc, dstType, result ,
616
- shiftValue);
617
- result = rewriter.create <spirv::ShiftRightArithmeticOp>(loc, dstType, result,
618
- shiftValue);
617
+ rewriter.createOrFold <spirv::ConstantOp>(loc, dstType, shiftValueAttr);
618
+ result = rewriter.createOrFold <spirv::ShiftLeftLogicalOp>(loc, dstType,
619
+ result, shiftValue);
620
+ result = rewriter.createOrFold <spirv::ShiftRightArithmeticOp>(
621
+ loc, dstType, result, shiftValue);
619
622
620
623
rewriter.replaceOp (loadOp, result);
621
624
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
744
747
745
748
// Create a mask to clear the destination. E.g., if it is the second i8 in
746
749
// i32, 0xFFFF00FF is created.
747
- Value mask = rewriter.create <spirv::ConstantOp>(
750
+ Value mask = rewriter.createOrFold <spirv::ConstantOp>(
748
751
loc, dstType, rewriter.getIntegerAttr (dstType, (1 << srcBits) - 1 ));
749
- Value clearBitsMask =
750
- rewriter.create <spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
751
- clearBitsMask = rewriter.create <spirv::NotOp>(loc, dstType, clearBitsMask);
752
+ Value clearBitsMask = rewriter.createOrFold <spirv::ShiftLeftLogicalOp>(
753
+ loc, dstType, mask, offset);
754
+ clearBitsMask =
755
+ rewriter.createOrFold <spirv::NotOp>(loc, dstType, clearBitsMask);
752
756
753
757
Value storeVal = shiftValue (loc, adaptor.getValue (), offset, mask, rewriter);
754
758
Value adjustedPtr = adjustAccessChainForBitwidth (typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
910
914
911
915
int64_t attrVal = cast<IntegerAttr>(offset.get <Attribute>()).getInt ();
912
916
Attribute attr = rewriter.getIntegerAttr (intType, attrVal);
913
- return rewriter.create <spirv::ConstantOp>(loc, intType, attr);
917
+ return rewriter.createOrFold <spirv::ConstantOp>(loc, intType, attr);
914
918
}();
915
919
916
920
rewriter.replaceOpWithNewOp <spirv::InBoundsPtrAccessChainOp>(
0 commit comments