Skip to content

[mlir][spirv][memref] Calculate alignment for PhysicalStorageBuffers #80243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 97 additions & 11 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"
Expand Down Expand Up @@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//

using AlignmentRequirements =
FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};

// PhysicalStorageBuffers require the `Aligned` attribute.
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
if (!pointeeType)
return failure();

// For scalar types, the alignment is determined by their size.
std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
if (!sizeInBytes.has_value())
return failure();

MLIRContext *ctx = accessedPtr.getContext();
auto memAccessAttr =
spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
return std::pair{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
static AlignmentRequirements
calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert(memrefAccessOp);
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");

auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
if (memrefMemAccess && memrefAlignment)
return std::pair{memrefMemAccess, memrefAlignment};

return calculateRequiredAlignment(accessedPtr);
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
AlignmentRequirements alignmentRequirements =
calculateRequiredAlignment(accessChain, loadOp);
if (failed(alignmentRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");

auto [memoryAccess, alignment] = *alignmentRequirements;
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
Expand All @@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
loc, dstType, adjustedPtr,
loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>()),
loadOp->getAttrOfType<IntegerAttr>("alignment"));
AlignmentRequirements alignmentRequirements =
calculateRequiredAlignment(adjustedPtr, loadOp);
if (failed(alignmentRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");

auto [memoryAccess, alignment] = *alignmentRequirements;
Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
memoryAccess, alignment);

// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
Expand Down Expand Up @@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto loadPtr = spirv::getElementPtr(
Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);

if (!loadPtr)
return failure();

rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");

auto [memAccessAttr, alignment] = *requiredAlignment;
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
alignment);
return success();
}

Expand Down Expand Up @@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
assert(dstBits % srcBits == 0);

if (srcBits == dstBits) {
AlignmentRequirements requiredAlignment =
calculateRequiredAlignment(accessChain);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine alignment requirements");

auto [memAccessAttr, alignment] = *requiredAlignment;
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
memAccessAttr, alignment);
return success();
}

Expand Down Expand Up @@ -768,8 +847,15 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!storePtr)
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
adaptor.getValue());
AlignmentRequirements requiredAlignment =
calculateRequiredAlignment(storePtr, storeOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine alignment requirements");

auto [memAccessAttr, alignment] = *requiredAlignment;
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
return success();
}

Expand Down
55 changes: 51 additions & 4 deletions mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s

// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0,
#spirv.vce<v1.5,
[
Shader, Int8, Int16, Int64, Float16, Float64,
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
PhysicalStorageBufferAddresses
],
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
#spirv.resource_limits<>>
} {

// CHECK-LABEL: @load_store_zero_rank_float
Expand Down Expand Up @@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
return
}

// CHECK-LABEL: @load_store_i32_physical
func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
%0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
return
}

// CHECK-LABEL: @load_store_i8_physical
func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
%0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
return
}

// CHECK-LABEL: @load_store_i1_physical
func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
%0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
return
}

// CHECK-LABEL: @load_store_f32_physical
func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
%0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
return
}

// CHECK-LABEL: @load_store_f16_physical
func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
%0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
return
}

} // end module

// -----
Expand Down