Skip to content

Commit 887e1aa

Browse files
committed
[mlir][spirv] Fix sub-word memref.store conversion
Support environments where logical types do not necessarily correspond to allowed storage access types. Also make pattern match failures more descriptive. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D159386
1 parent c1eacc3 commit 887e1aa

File tree

2 files changed

+154
-42
lines changed

2 files changed

+154
-42
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,37 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
7777
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
7878
}
7979

80-
/// Returns the shifted `targetBits`-bit value with the given offset.
80+
/// Casts the given `srcBool` into an integer of `dstType`.
81+
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
82+
OpBuilder &builder) {
83+
assert(srcBool.getType().isInteger(1));
84+
if (dstType.isInteger(1))
85+
return srcBool;
86+
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
87+
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
88+
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
89+
}
90+
91+
/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
92+
/// to the type destination type, and masked.
8193
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
82-
int targetBits, OpBuilder &builder) {
83-
Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
84-
return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), result,
94+
OpBuilder &builder) {
95+
IntegerType dstType = cast<IntegerType>(mask.getType());
96+
int targetBits = static_cast<int>(dstType.getWidth());
97+
int valueBits = value.getType().getIntOrFloatBitWidth();
98+
assert(valueBits <= targetBits);
99+
100+
if (valueBits == 1) {
101+
value = castBoolToIntN(loc, value, dstType, builder);
102+
} else {
103+
if (valueBits < targetBits) {
104+
value = builder.create<spirv::UConvertOp>(
105+
loc, builder.getIntegerType(targetBits), value);
106+
}
107+
108+
value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
109+
}
110+
return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
85111
offset);
86112
}
87113

@@ -136,17 +162,6 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
136162
return builder.create<spirv::IEqualOp>(loc, srcInt, one);
137163
}
138164

139-
/// Casts the given `srcBool` into an integer of `dstType`.
140-
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
141-
OpBuilder &builder) {
142-
assert(srcBool.getType().isInteger(1));
143-
if (dstType.isInteger(1))
144-
return srcBool;
145-
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
146-
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
147-
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
148-
}
149-
150165
//===----------------------------------------------------------------------===//
151166
// Operation conversion
152167
//===----------------------------------------------------------------------===//
@@ -553,7 +568,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
553568
ConversionPatternRewriter &rewriter) const {
554569
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
555570
if (!memrefType.getElementType().isSignlessInteger())
556-
return failure();
571+
return rewriter.notifyMatchFailure(storeOp,
572+
"element type is not a signless int");
557573

558574
auto loc = storeOp.getLoc();
559575
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
@@ -562,7 +578,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
562578
adaptor.getIndices(), loc, rewriter);
563579

564580
if (!accessChain)
565-
return failure();
581+
return rewriter.notifyMatchFailure(
582+
storeOp, "failed to convert element pointer type");
566583

567584
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
568585

@@ -576,23 +593,28 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
576593
"failed to convert memref type");
577594

578595
Type pointeeType = pointerType.getPointeeType();
579-
Type dstType;
596+
IntegerType dstType;
580597
if (typeConverter.allows(spirv::Capability::Kernel)) {
581598
if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
582-
dstType = arrayType.getElementType();
599+
dstType = dyn_cast<IntegerType>(arrayType.getElementType());
583600
else
584-
dstType = pointeeType;
601+
dstType = dyn_cast<IntegerType>(pointeeType);
585602
} else {
586603
// For Vulkan we need to extract element from wrapping struct and array.
587604
Type structElemType =
588605
cast<spirv::StructType>(pointeeType).getElementType(0);
589606
if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
590-
dstType = arrayType.getElementType();
607+
dstType = dyn_cast<IntegerType>(arrayType.getElementType());
591608
else
592-
dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
609+
dstType = dyn_cast<IntegerType>(
610+
cast<spirv::RuntimeArrayType>(structElemType).getElementType());
593611
}
594612

595-
int dstBits = dstType.getIntOrFloatBitWidth();
613+
if (!dstType)
614+
return rewriter.notifyMatchFailure(
615+
storeOp, "failed to determine destination element type");
616+
617+
int dstBits = static_cast<int>(dstType.getWidth());
596618
assert(dstBits % srcBits == 0);
597619

598620
if (srcBits == dstBits) {
@@ -612,17 +634,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
612634
if (!accessChainOp)
613635
return failure();
614636

615-
// Since there are multi threads in the processing, the emulation will be done
616-
// with atomic operations. E.g., if the storing value is i8, rewrite the
617-
// StoreOp to
637+
// Since there are multiple threads in the processing, the emulation will be
638+
// done with atomic operations. E.g., if the stored value is i8, rewrite the
639+
// StoreOp to:
618640
// 1) load a 32-bit integer
619-
// 2) clear 8 bits in the loading value
620-
// 3) store 32-bit value back
621-
// 4) load a 32-bit integer
622-
// 5) modify 8 bits in the loading value
623-
// 6) store 32-bit value back
624-
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
625-
// 4 to step 6 are done by AtomicOr as another atomic step.
641+
// 2) clear 8 bits in the loaded value
642+
// 3) set 8 bits in the loaded value
643+
// 4) store 32-bit value back
644+
//
645+
// Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
646+
// loaded 32-bit value and the shifted 8-bit store value) as another atomic
647+
// step.
626648
assert(accessChainOp.getIndices().size() == 2);
627649
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
628650
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
@@ -635,15 +657,13 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
635657
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
636658
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
637659

638-
Value storeVal = adaptor.getValue();
639-
if (isBool)
640-
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
641-
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
660+
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
642661
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
643662
srcBits, dstBits, rewriter);
644663
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
645664
if (!scope)
646-
return failure();
665+
return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
666+
647667
Value result = rewriter.create<spirv::AtomicAndOp>(
648668
loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
649669
clearBitsMask);
@@ -740,13 +760,13 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
740760
ConversionPatternRewriter &rewriter) const {
741761
auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
742762
if (memrefType.getElementType().isSignlessInteger())
743-
return failure();
763+
return rewriter.notifyMatchFailure(storeOp, "signless int");
744764
auto storePtr = spirv::getElementPtr(
745765
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
746766
adaptor.getIndices(), storeOp.getLoc(), rewriter);
747767

748768
if (!storePtr)
749-
return failure();
769+
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
750770

751771
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
752772
adaptor.getValue());

mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %val
119119
// CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
120120
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
121121
// CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
122-
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
123-
// CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
122+
// CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
124123
// CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
125124
// CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
126125
// CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
@@ -270,3 +269,96 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v
270269
}
271270

272271
} // end module
272+
273+
// -----
274+
275+
// Check that we can access i8 storage with i8 types available but without
276+
// 8-bit storage capabilities.
277+
module attributes {
278+
spirv.target_env = #spirv.target_env<
279+
#spirv.vce<v1.0, [Shader, Int64, Int8], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
280+
} {
281+
282+
// CHECK-LABEL: @load_i8
283+
// INDEX64-LABEL: @load_i8
284+
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
285+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
286+
// CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
287+
// CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
288+
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
289+
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
290+
// CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
291+
// CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
292+
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
293+
// CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
294+
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
295+
// CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
296+
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
297+
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
298+
// CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
299+
// CHECK: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
300+
// CHECK: return %[[CAST]] : i8
301+
302+
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
303+
// INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
304+
// INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
305+
// INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
306+
// INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
307+
// INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
308+
// INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
309+
// INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
310+
// INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
311+
// INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
312+
// INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
313+
// INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
314+
// INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
315+
// INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
316+
// INDEX64: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
317+
// INDEX64: return %[[CAST]] : i8
318+
%0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
319+
return %0 : i8
320+
}
321+
322+
// CHECK-LABEL: @store_i8
323+
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
324+
// INDEX64-LABEL: @store_i8
325+
// INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
326+
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
327+
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
328+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
329+
// CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
330+
// CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
331+
// CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
332+
// CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
333+
// CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
334+
// CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
335+
// CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
336+
// CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
337+
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
338+
// CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
339+
// CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
340+
// CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
341+
// CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
342+
// CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
343+
344+
// INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
345+
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
346+
// INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
347+
// INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
348+
// INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
349+
// INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
350+
// INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
351+
// INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
352+
// INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
353+
// INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
354+
// INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
355+
// INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
356+
// INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
357+
// INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
358+
// INDEX64: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
359+
// INDEX64: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
360+
memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
361+
return
362+
}
363+
364+
} // end module

0 commit comments

Comments
 (0)