Skip to content

Commit 8fd0bce

Browse files
authored
[mlir][spirv][memref] Calculate alignment for PhysicalStorageBuffers (#80243)
The SPIR-V spec requires that memory accesses to `PhysicalStorageBuffer`s are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes. [1] Otherwise spirv-val complains: ``` [VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned. %48 = OpLoad %float %47 ```
1 parent 4a653b4 commit 8fd0bce

File tree

2 files changed

+148
-15
lines changed

2 files changed

+148
-15
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 97 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212

1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1516
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1617
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1718
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1819
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+
#include "mlir/IR/BuiltinAttributes.h"
1921
#include "mlir/IR/BuiltinTypes.h"
22+
#include "mlir/IR/MLIRContext.h"
23+
#include "mlir/IR/Visitors.h"
24+
#include "mlir/Support/LogicalResult.h"
2025
#include "llvm/Support/Debug.h"
26+
#include <cassert>
2127
#include <optional>
2228

2329
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
439445
// LoadOp
440446
//===----------------------------------------------------------------------===//
441447

448+
using AlignmentRequirements =
449+
FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
450+
451+
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
452+
/// any.
453+
static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
454+
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
455+
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
456+
return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
457+
458+
// PhysicalStorageBuffers require the `Aligned` attribute.
459+
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
460+
if (!pointeeType)
461+
return failure();
462+
463+
// For scalar types, the alignment is determined by their size.
464+
std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
465+
if (!sizeInBytes.has_value())
466+
return failure();
467+
468+
MLIRContext *ctx = accessedPtr.getContext();
469+
auto memAccessAttr =
470+
spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
471+
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
472+
return std::pair{memAccessAttr, alignment};
473+
}
474+
475+
/// Given an accessed SPIR-V pointer and the original memref load/store
476+
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
477+
/// account the alignment attributes applied to the load/store op.
478+
static AlignmentRequirements
479+
calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
480+
assert(memrefAccessOp);
481+
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
482+
"Bad op type");
483+
484+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
485+
spirv::attributeName<spirv::MemoryAccess>());
486+
auto memrefAlignment =
487+
memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
488+
if (memrefMemAccess && memrefAlignment)
489+
return std::pair{memrefMemAccess, memrefAlignment};
490+
491+
return calculateRequiredAlignment(accessedPtr);
492+
}
493+
442494
LogicalResult
443495
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
444496
ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
486538
// If the rewritten load op has the same bit width, use the loading value
487539
// directly.
488540
if (srcBits == dstBits) {
489-
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
541+
AlignmentRequirements alignmentRequirements =
542+
calculateRequiredAlignment(accessChain, loadOp);
543+
if (failed(alignmentRequirements))
544+
return rewriter.notifyMatchFailure(
545+
loadOp, "failed to determine alignment requirements");
546+
547+
auto [memoryAccess, alignment] = *alignmentRequirements;
548+
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
549+
memoryAccess, alignment);
490550
if (isBool)
491551
loadVal = castIntNToBool(loc, loadVal, rewriter);
492552
rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
508568
assert(accessChainOp.getIndices().size() == 2);
509569
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
510570
srcBits, dstBits, rewriter);
511-
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
512-
loc, dstType, adjustedPtr,
513-
loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
514-
spirv::attributeName<spirv::MemoryAccess>()),
515-
loadOp->getAttrOfType<IntegerAttr>("alignment"));
571+
AlignmentRequirements alignmentRequirements =
572+
calculateRequiredAlignment(adjustedPtr, loadOp);
573+
if (failed(alignmentRequirements))
574+
return rewriter.notifyMatchFailure(
575+
loadOp, "failed to determine alignment requirements");
576+
577+
auto [memoryAccess, alignment] = *alignmentRequirements;
578+
Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
579+
memoryAccess, alignment);
516580

517581
// Shift the bits to the rightmost.
518582
// ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
552616
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
553617
if (memrefType.getElementType().isSignlessInteger())
554618
return failure();
555-
auto loadPtr = spirv::getElementPtr(
619+
Value loadPtr = spirv::getElementPtr(
556620
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
557621
adaptor.getIndices(), loadOp.getLoc(), rewriter);
558622

559623
if (!loadPtr)
560624
return failure();
561625

562-
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
626+
AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
627+
if (failed(requiredAlignment))
628+
return rewriter.notifyMatchFailure(
629+
loadOp, "failed to determine alignment requirements");
630+
631+
auto [memAccessAttr, alignment] = *requiredAlignment;
632+
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
633+
alignment);
563634
return success();
564635
}
565636

@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
618689
assert(dstBits % srcBits == 0);
619690

620691
if (srcBits == dstBits) {
692+
AlignmentRequirements requiredAlignment =
693+
calculateRequiredAlignment(accessChain);
694+
if (failed(requiredAlignment))
695+
return rewriter.notifyMatchFailure(
696+
storeOp, "failed to determine alignment requirements");
697+
698+
auto [memAccessAttr, alignment] = *requiredAlignment;
621699
Value storeVal = adaptor.getValue();
622700
if (isBool)
623701
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
624-
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
702+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
703+
memAccessAttr, alignment);
625704
return success();
626705
}
627706

@@ -768,8 +847,15 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
768847
if (!storePtr)
769848
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
770849

771-
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
772-
adaptor.getValue());
850+
AlignmentRequirements requiredAlignment =
851+
calculateRequiredAlignment(storePtr, storeOp);
852+
if (failed(requiredAlignment))
853+
return rewriter.notifyMatchFailure(
854+
storeOp, "failed to determine alignment requirements");
855+
856+
auto [memAccessAttr, alignment] = *requiredAlignment;
857+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
858+
storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
773859
return success();
774860
}
775861

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
1+
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
22

33
// Check that with proper compute and storage extensions, we don't need to
44
// perform special tricks.
55

66
module attributes {
77
spirv.target_env = #spirv.target_env<
8-
#spirv.vce<v1.0,
8+
#spirv.vce<v1.5,
99
[
1010
Shader, Int8, Int16, Int64, Float16, Float64,
1111
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
12-
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
12+
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
13+
PhysicalStorageBufferAddresses
1314
],
14-
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
15+
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
16+
#spirv.resource_limits<>>
1517
} {
1618

1719
// CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
119121
return
120122
}
121123

124+
// CHECK-LABEL: @load_store_i32_physical
125+
func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
126+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
127+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
128+
%0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
129+
memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
130+
return
131+
}
132+
133+
// CHECK-LABEL: @load_store_i8_physical
134+
func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
135+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
136+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
137+
%0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
138+
memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
139+
return
140+
}
141+
142+
// CHECK-LABEL: @load_store_i1_physical
143+
func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
144+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
145+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
146+
%0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
147+
memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
148+
return
149+
}
150+
151+
// CHECK-LABEL: @load_store_f32_physical
152+
func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
153+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
154+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
155+
%0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
156+
memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
157+
return
158+
}
159+
160+
// CHECK-LABEL: @load_store_f16_physical
161+
func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
162+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
163+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
164+
%0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
165+
memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
166+
return
167+
}
168+
122169
} // end module
123170

124171
// -----

0 commit comments

Comments
 (0)