Skip to content

[mlir][spirv] Retain nontemporal attribute when converting memref load/store #82119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 2, 2024

Conversation

agentcooper
Copy link
Contributor

Fixes #77156.

@llvmbot
Copy link
Member

llvmbot commented Feb 17, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Artem Tyurin (agentcooper)

Changes

Fixes #77156.


Full diff: https://github.com/llvm/llvm-project/pull/82119.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+10-1)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+14)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
   assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
          "Bad op type");
 
+  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+  if (nontemporalAttr && nontemporalAttr.getValue()) {
+    return std::pair{
+        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+                                     spirv::MemoryAccess::Nontemporal),
+        IntegerAttr{}};
+  }
+
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
   auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 }
 
 }
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+    return
+  }
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 17, 2024

@llvm/pr-subscribers-mlir

Author: Artem Tyurin (agentcooper)

Changes

Fixes #77156.


Full diff: https://github.com/llvm/llvm-project/pull/82119.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+10-1)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+14)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
   assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
          "Bad op type");
 
+  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+  if (nontemporalAttr && nontemporalAttr.getValue()) {
+    return std::pair{
+        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+                                     spirv::MemoryAccess::Nontemporal),
+        IntegerAttr{}};
+  }
+
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
   auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 }
 
 }
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+    return
+  }
+}

@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");

auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nontemproal is an attribute codified in the memref load/store op's definition. You can have direct accessors like loadOp.getNontemproal().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this function takes Operation * to allow both load and store, this seems to be the most convenient way to access the attribute. I can extract "nontemporal" string as a constant if you can suggest a good place for it. Or would you prefer casting to memref::LoadOp/memref::StoreOp instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make this a templated function so that we can use the accessor directly.

@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();

AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
AlignmentRequirements requiredAlignment =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is calculating the alignement; non termporal is not about the alignment. So I'm not sure we should put it here?

Copy link
Contributor Author

@agentcooper agentcooper Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using AlignmentRequirements = FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;

It does calculate both memory access and alignment and nontemporal is spirv::MemoryAccess::Nontemporal, so I think it fits.

We can rename it to MemoryRequirements maybe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah renaming it would be good. As the function name shows it's about alignment calculation right now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");

auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
if (nontemporalAttr && nontemporalAttr.getValue()) {
return std::pair{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the Nontemporal is a bit in the bitfield--it should be or-ed together with the Aligned bit calculuated in the below--as it's implemented right now, this breaks the logic for alignment calculation.

We need to have a test on this (nontemporal + aligned) to check it.

@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");

auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make this a templated function so that we can use the accessor directly.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar kuhar self-requested a review February 27, 2024 18:43
@antiagainst
Copy link
Member

Nice. Thanks for fixing this and addressing all the comments! :)

@antiagainst antiagainst merged commit def16bc into llvm:main Mar 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Retain nontemporal information when converting memref load/store
4 participants