Skip to content

Commit def16bc

Browse files
authored
[mlir][spirv] Retain nontemporal attribute when converting memref load/store (#82119)
Fixes #77156.
1 parent 0ec318e commit def16bc

File tree

2 files changed

+87
-44
lines changed

2 files changed

+87
-44
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -445,15 +445,30 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
445445
// LoadOp
446446
//===----------------------------------------------------------------------===//
447447

448-
using AlignmentRequirements =
449-
FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
448+
struct MemoryRequirements {
449+
spirv::MemoryAccessAttr memoryAccess;
450+
IntegerAttr alignment;
451+
};
450452

451453
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
452454
/// any.
453-
static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
455+
static FailureOr<MemoryRequirements>
456+
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
457+
MLIRContext *ctx = accessedPtr.getContext();
458+
459+
auto memoryAccess = spirv::MemoryAccess::None;
460+
if (isNontemporal) {
461+
memoryAccess = spirv::MemoryAccess::Nontemporal;
462+
}
463+
454464
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
455-
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
456-
return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
465+
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
466+
if (memoryAccess == spirv::MemoryAccess::None) {
467+
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
468+
}
469+
return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
470+
IntegerAttr{}};
471+
}
457472

458473
// PhysicalStorageBuffers require the `Aligned` attribute.
459474
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
@@ -465,30 +480,32 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
465480
if (!sizeInBytes.has_value())
466481
return failure();
467482

468-
MLIRContext *ctx = accessedPtr.getContext();
469-
auto memAccessAttr =
470-
spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
483+
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
484+
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
471485
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
472-
return std::pair{memAccessAttr, alignment};
486+
return MemoryRequirements{memAccessAttr, alignment};
473487
}
474488

475489
/// Given an accessed SPIR-V pointer and the original memref load/store
476490
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
477491
/// account the alignment attributes applied to the load/store op.
478-
static AlignmentRequirements
479-
calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
480-
assert(memrefAccessOp);
481-
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
482-
"Bad op type");
483-
492+
template <class LoadOrStoreOp>
493+
static FailureOr<MemoryRequirements>
494+
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
495+
static_assert(
496+
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
497+
"Must be called on either memref::LoadOp or memref::StoreOp");
498+
499+
Operation *memrefAccessOp = loadOrStoreOp.getOperation();
484500
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
485501
spirv::attributeName<spirv::MemoryAccess>());
486502
auto memrefAlignment =
487503
memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
488504
if (memrefMemAccess && memrefAlignment)
489-
return std::pair{memrefMemAccess, memrefAlignment};
505+
return MemoryRequirements{memrefMemAccess, memrefAlignment};
490506

491-
return calculateRequiredAlignment(accessedPtr);
507+
return calculateMemoryRequirements(accessedPtr,
508+
loadOrStoreOp.getNontemporal());
492509
}
493510

494511
LogicalResult
@@ -538,13 +555,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
538555
// If the rewritten load op has the same bit width, use the loading value
539556
// directly.
540557
if (srcBits == dstBits) {
541-
AlignmentRequirements alignmentRequirements =
542-
calculateRequiredAlignment(accessChain, loadOp);
543-
if (failed(alignmentRequirements))
558+
auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
559+
if (failed(memoryRequirements))
544560
return rewriter.notifyMatchFailure(
545-
loadOp, "failed to determine alignment requirements");
561+
loadOp, "failed to determine memory requirements");
546562

547-
auto [memoryAccess, alignment] = *alignmentRequirements;
563+
auto [memoryAccess, alignment] = *memoryRequirements;
548564
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
549565
memoryAccess, alignment);
550566
if (isBool)
@@ -568,13 +584,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
568584
assert(accessChainOp.getIndices().size() == 2);
569585
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
570586
srcBits, dstBits, rewriter);
571-
AlignmentRequirements alignmentRequirements =
572-
calculateRequiredAlignment(adjustedPtr, loadOp);
573-
if (failed(alignmentRequirements))
587+
auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
588+
if (failed(memoryRequirements))
574589
return rewriter.notifyMatchFailure(
575-
loadOp, "failed to determine alignment requirements");
590+
loadOp, "failed to determine memory requirements");
576591

577-
auto [memoryAccess, alignment] = *alignmentRequirements;
592+
auto [memoryAccess, alignment] = *memoryRequirements;
578593
Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
579594
memoryAccess, alignment);
580595

@@ -623,13 +638,13 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
623638
if (!loadPtr)
624639
return failure();
625640

626-
AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
627-
if (failed(requiredAlignment))
641+
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
642+
if (failed(memoryRequirements))
628643
return rewriter.notifyMatchFailure(
629-
loadOp, "failed to determine alignment requirements");
644+
loadOp, "failed to determine memory requirements");
630645

631-
auto [memAccessAttr, alignment] = *requiredAlignment;
632-
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
646+
auto [memoryAccess, alignment] = *memoryRequirements;
647+
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
633648
alignment);
634649
return success();
635650
}
@@ -689,18 +704,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
689704
assert(dstBits % srcBits == 0);
690705

691706
if (srcBits == dstBits) {
692-
AlignmentRequirements requiredAlignment =
693-
calculateRequiredAlignment(accessChain);
694-
if (failed(requiredAlignment))
707+
auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
708+
if (failed(memoryRequirements))
695709
return rewriter.notifyMatchFailure(
696-
storeOp, "failed to determine alignment requirements");
710+
storeOp, "failed to determine memory requirements");
697711

698-
auto [memAccessAttr, alignment] = *requiredAlignment;
712+
auto [memoryAccess, alignment] = *memoryRequirements;
699713
Value storeVal = adaptor.getValue();
700714
if (isBool)
701715
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
702716
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
703-
memAccessAttr, alignment);
717+
memoryAccess, alignment);
704718
return success();
705719
}
706720

@@ -847,15 +861,14 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
847861
if (!storePtr)
848862
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
849863

850-
AlignmentRequirements requiredAlignment =
851-
calculateRequiredAlignment(storePtr, storeOp);
852-
if (failed(requiredAlignment))
864+
auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
865+
if (failed(memoryRequirements))
853866
return rewriter.notifyMatchFailure(
854-
storeOp, "failed to determine alignment requirements");
867+
storeOp, "failed to determine memory requirements");
855868

856-
auto [memAccessAttr, alignment] = *requiredAlignment;
869+
auto [memoryAccess, alignment] = *memoryRequirements;
857870
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
858-
storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
871+
storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
859872
return success();
860873
}
861874

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,33 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
431431
}
432432

433433
}
434+
435+
// -----
436+
437+
// Check nontemporal attribute
438+
439+
module attributes {
440+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
441+
Shader,
442+
PhysicalStorageBufferAddresses
443+
], [
444+
SPV_KHR_storage_buffer_storage_class,
445+
SPV_KHR_physical_storage_buffer
446+
]>, #spirv.resource_limits<>>
447+
} {
448+
func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
449+
%0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
450+
// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
451+
memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
452+
// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
453+
return
454+
}
455+
456+
func.func @load_nontemporal_aligned(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
457+
%0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
458+
// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned|Nontemporal", 4] : f32
459+
memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
460+
// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned|Nontemporal", 4] : f32
461+
return
462+
}
463+
}

0 commit comments

Comments
 (0)