-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Retain nontemporal attribute when converting memref load/store #82119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Retain nontemporal attribute when converting memref load/store #82119
Conversation
@llvm/pr-subscribers-mlir-spirv Author: Artem Tyurin (agentcooper) ChangesFixes #77156. Full diff: https://github.com/llvm/llvm-project/pull/82119.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");
+ auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+ if (nontemporalAttr && nontemporalAttr.getValue()) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+ spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
}
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+ return
+ }
+}
|
@llvm/pr-subscribers-mlir Author: Artem Tyurin (agentcooper) ChangesFixes #77156. Full diff: https://github.com/llvm/llvm-project/pull/82119.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");
+ auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+ if (nontemporalAttr && nontemporalAttr.getValue()) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+ spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
}
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+ return
+ }
+}
|
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) { | |||
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) && | |||
"Bad op type"); | |||
|
|||
auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nontemproal
is an attribute codified in the memref load/store op's definition. You can have direct accessors like loadOp.getNontemproal()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As this function takes Operation *
to allow both load and store, this seems to be the most convenient way to access the attribute. I can extract "nontemporal"
string as a constant if you can suggest a good place for it. Or would you prefer casting to memref::LoadOp
/memref::StoreOp
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make this a templated function so that we can use the accessor directly.
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, | |||
if (!loadPtr) | |||
return failure(); | |||
|
|||
AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr); | |||
AlignmentRequirements requiredAlignment = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is calculating the alignement; non termporal is not about the alignment. So I'm not sure we should put it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using AlignmentRequirements = FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
It does calculate both memory access and alignment and nontemporal is spirv::MemoryAccess::Nontemporal
, so I think it fits.
We can rename it to MemoryRequirements
maybe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah renaming it would be good. As the function name shows it's about alignment calculation right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done 👍
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) { | |||
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) && | |||
"Bad op type"); | |||
|
|||
auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal"); | |||
if (nontemporalAttr && nontemporalAttr.getValue()) { | |||
return std::pair{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that the Nontemporal
is a bit in the bitfield--it should be or-ed together with the Aligned
bit calculuated in the below--as it's implemented right now, this breaks the logic for alignment calculation.
We need to have a test on this (nontemporal + aligned) to check it.
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) { | |||
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) && | |||
"Bad op type"); | |||
|
|||
auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make this a templated function so that we can use the accessor directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Nice. Thanks for fixing this and addressing all the comments! :) |
Fixes #77156.