@@ -77,11 +77,37 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
77
77
return builder.create <spirv::AccessChainOp>(loc, t, op.getBasePtr (), indices);
78
78
}
79
79
80
- // / Returns the shifted `targetBits`-bit value with the given offset.
80
+ // / Casts the given `srcBool` into an integer of `dstType`.
81
+ static Value castBoolToIntN (Location loc, Value srcBool, Type dstType,
82
+ OpBuilder &builder) {
83
+ assert (srcBool.getType ().isInteger (1 ));
84
+ if (dstType.isInteger (1 ))
85
+ return srcBool;
86
+ Value zero = spirv::ConstantOp::getZero (dstType, loc, builder);
87
+ Value one = spirv::ConstantOp::getOne (dstType, loc, builder);
88
+ return builder.create <spirv::SelectOp>(loc, dstType, srcBool, one, zero);
89
+ }
90
+
91
+ // / Returns the `targetBits`-bit value shifted by the given `offset`, and cast
92
+ // / to the type destination type, and masked.
81
93
static Value shiftValue (Location loc, Value value, Value offset, Value mask,
82
- int targetBits, OpBuilder &builder) {
83
- Value result = builder.create <spirv::BitwiseAndOp>(loc, value, mask);
84
- return builder.create <spirv::ShiftLeftLogicalOp>(loc, value.getType (), result,
94
+ OpBuilder &builder) {
95
+ IntegerType dstType = cast<IntegerType>(mask.getType ());
96
+ int targetBits = static_cast <int >(dstType.getWidth ());
97
+ int valueBits = value.getType ().getIntOrFloatBitWidth ();
98
+ assert (valueBits <= targetBits);
99
+
100
+ if (valueBits == 1 ) {
101
+ value = castBoolToIntN (loc, value, dstType, builder);
102
+ } else {
103
+ if (valueBits < targetBits) {
104
+ value = builder.create <spirv::UConvertOp>(
105
+ loc, builder.getIntegerType (targetBits), value);
106
+ }
107
+
108
+ value = builder.create <spirv::BitwiseAndOp>(loc, value, mask);
109
+ }
110
+ return builder.create <spirv::ShiftLeftLogicalOp>(loc, value.getType (), value,
85
111
offset);
86
112
}
87
113
@@ -136,17 +162,6 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
136
162
return builder.create <spirv::IEqualOp>(loc, srcInt, one);
137
163
}
138
164
139
- // / Casts the given `srcBool` into an integer of `dstType`.
140
- static Value castBoolToIntN (Location loc, Value srcBool, Type dstType,
141
- OpBuilder &builder) {
142
- assert (srcBool.getType ().isInteger (1 ));
143
- if (dstType.isInteger (1 ))
144
- return srcBool;
145
- Value zero = spirv::ConstantOp::getZero (dstType, loc, builder);
146
- Value one = spirv::ConstantOp::getOne (dstType, loc, builder);
147
- return builder.create <spirv::SelectOp>(loc, dstType, srcBool, one, zero);
148
- }
149
-
150
165
// ===----------------------------------------------------------------------===//
151
166
// Operation conversion
152
167
// ===----------------------------------------------------------------------===//
@@ -553,7 +568,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
553
568
ConversionPatternRewriter &rewriter) const {
554
569
auto memrefType = cast<MemRefType>(storeOp.getMemref ().getType ());
555
570
if (!memrefType.getElementType ().isSignlessInteger ())
556
- return failure ();
571
+ return rewriter.notifyMatchFailure (storeOp,
572
+ " element type is not a signless int" );
557
573
558
574
auto loc = storeOp.getLoc ();
559
575
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
@@ -562,7 +578,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
562
578
adaptor.getIndices (), loc, rewriter);
563
579
564
580
if (!accessChain)
565
- return failure ();
581
+ return rewriter.notifyMatchFailure (
582
+ storeOp, " failed to convert element pointer type" );
566
583
567
584
int srcBits = memrefType.getElementType ().getIntOrFloatBitWidth ();
568
585
@@ -576,23 +593,28 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
576
593
" failed to convert memref type" );
577
594
578
595
Type pointeeType = pointerType.getPointeeType ();
579
- Type dstType;
596
+ IntegerType dstType;
580
597
if (typeConverter.allows (spirv::Capability::Kernel)) {
581
598
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
582
- dstType = arrayType.getElementType ();
599
+ dstType = dyn_cast<IntegerType>( arrayType.getElementType () );
583
600
else
584
- dstType = pointeeType;
601
+ dstType = dyn_cast<IntegerType>( pointeeType) ;
585
602
} else {
586
603
// For Vulkan we need to extract element from wrapping struct and array.
587
604
Type structElemType =
588
605
cast<spirv::StructType>(pointeeType).getElementType (0 );
589
606
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
590
- dstType = arrayType.getElementType ();
607
+ dstType = dyn_cast<IntegerType>( arrayType.getElementType () );
591
608
else
592
- dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType ();
609
+ dstType = dyn_cast<IntegerType>(
610
+ cast<spirv::RuntimeArrayType>(structElemType).getElementType ());
593
611
}
594
612
595
- int dstBits = dstType.getIntOrFloatBitWidth ();
613
+ if (!dstType)
614
+ return rewriter.notifyMatchFailure (
615
+ storeOp, " failed to determine destination element type" );
616
+
617
+ int dstBits = static_cast <int >(dstType.getWidth ());
596
618
assert (dstBits % srcBits == 0 );
597
619
598
620
if (srcBits == dstBits) {
@@ -612,17 +634,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
612
634
if (!accessChainOp)
613
635
return failure ();
614
636
615
- // Since there are multi threads in the processing, the emulation will be done
616
- // with atomic operations. E.g., if the storing value is i8, rewrite the
617
- // StoreOp to
637
+ // Since there are multiple threads in the processing, the emulation will be
638
+ // done with atomic operations. E.g., if the stored value is i8, rewrite the
639
+ // StoreOp to:
618
640
// 1) load a 32-bit integer
619
- // 2) clear 8 bits in the loading value
620
- // 3) store 32-bit value back
621
- // 4) load a 32-bit integer
622
- // 5) modify 8 bits in the loading value
623
- // 6) store 32-bit value back
624
- // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
625
- // 4 to step 6 are done by AtomicOr as another atomic step.
641
+ // 2) clear 8 bits in the loaded value
642
+ // 3) set 8 bits in the loaded value
643
+ // 4) store 32-bit value back
644
+ //
645
+ // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
646
+ // loaded 32-bit value and the shifted 8-bit store value) as another atomic
647
+ // step.
626
648
assert (accessChainOp.getIndices ().size () == 2 );
627
649
Value lastDim = accessChainOp->getOperand (accessChainOp.getNumOperands () - 1 );
628
650
Value offset = getOffsetForBitwidth (loc, lastDim, srcBits, dstBits, rewriter);
@@ -635,15 +657,13 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
635
657
rewriter.create <spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
636
658
clearBitsMask = rewriter.create <spirv::NotOp>(loc, dstType, clearBitsMask);
637
659
638
- Value storeVal = adaptor.getValue ();
639
- if (isBool)
640
- storeVal = castBoolToIntN (loc, storeVal, dstType, rewriter);
641
- storeVal = shiftValue (loc, storeVal, offset, mask, dstBits, rewriter);
660
+ Value storeVal = shiftValue (loc, adaptor.getValue (), offset, mask, rewriter);
642
661
Value adjustedPtr = adjustAccessChainForBitwidth (typeConverter, accessChainOp,
643
662
srcBits, dstBits, rewriter);
644
663
std::optional<spirv::Scope> scope = getAtomicOpScope (memrefType);
645
664
if (!scope)
646
- return failure ();
665
+ return rewriter.notifyMatchFailure (storeOp, " atomic scope not available" );
666
+
647
667
Value result = rewriter.create <spirv::AtomicAndOp>(
648
668
loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
649
669
clearBitsMask);
@@ -740,13 +760,13 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
740
760
ConversionPatternRewriter &rewriter) const {
741
761
auto memrefType = cast<MemRefType>(storeOp.getMemref ().getType ());
742
762
if (memrefType.getElementType ().isSignlessInteger ())
743
- return failure ( );
763
+ return rewriter. notifyMatchFailure (storeOp, " signless int " );
744
764
auto storePtr = spirv::getElementPtr (
745
765
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref (),
746
766
adaptor.getIndices (), storeOp.getLoc (), rewriter);
747
767
748
768
if (!storePtr)
749
- return failure ( );
769
+ return rewriter. notifyMatchFailure (storeOp, " type conversion failed " );
750
770
751
771
rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, storePtr,
752
772
adaptor.getValue ());
0 commit comments