Skip to content

[mlir][memref] Unranked support for extract_aligned_pointer_as_index #93908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2024

Conversation

sabauma
Copy link
Contributor

@sabauma sabauma commented May 31, 2024

memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass.

%r = memref.reinterpret_cast %arg0 to
    offset: [0],
    sizes: [],
    strides: [] : memref<*xf32> to memref<f32>

%i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked tensors.

This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion.

memref.extract_aligned_pointer_as_index currently does not support
unranked inputs. This lack of support interferes with the folding
operations in the expand-strided-metadata pass.

    %r = memref.reinterpret_cast %arg0 to
        offset: [0],
        sizes: [],
        strides: [] : memref<*xf32> to memref<f32>

    %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked
tensors.

This change modifies the extract_aligned_pointer_as_index operation to
support unranked inputs with corresponding support in the MemRef->LLVM
conversion.
@llvmbot
Copy link
Member

llvmbot commented May 31, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Spenser Bauman (sabauma)

Changes

memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass.

%r = memref.reinterpret_cast %arg0 to
    offset: [0],
    sizes: [],
    strides: [] : memref&lt;*xf32&gt; to memref&lt;f32&gt;

%i = memref.extract_aligned_pointer_as_index %r : memref&lt;f32&gt; -&gt; index

Patterns like this occur when bufferizing operations on unranked tensors.

This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion.


Full diff: https://github.com/llvm/llvm-project/pull/93908.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+1-1)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+19-3)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+15)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..df40e7a17a15f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
   }];
 
   let arguments = (ins
-    AnyStridedMemRef:$source
+    AnyRankedOrUnrankedMemRef:$source
   );
   let results = (outs Index:$aligned_pointer);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2dc42f0a85e66..82c4b04656b33 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefDescriptor desc(adaptor.getSource());
+    BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+    Value alignedPtr;
+    if (sourceTy.hasRank()) {
+      MemRefDescriptor desc(adaptor.getSource());
+      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+    } else {
+      auto elementPtrTy = LLVM::LLVMPointerType::get(
+          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+      UnrankedMemRefDescriptor desc(adaptor.getSource());
+      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+          elementPtrTy);
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
-        extractOp, getTypeConverter()->getIndexType(),
-        desc.alignedPtr(rewriter, extractOp->getLoc()));
+        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index baf9cfe610a5a..882804132e66d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
+func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> 
+  // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R]] : index
+  return %0: index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_strided_metadata(
 // CHECK-SAME: %[[ARG:.*]]: memref
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 3bd6b7c1fd791..d884ade319532 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
 
 // -----
 
+// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
+func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
+  // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
+  // CHECK: return %[[I]]
+
+  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
+  %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
+  return %i : index
+}
+
+// -----
+
 // Check that we simplify collapse_shape into
 // reinterpret_cast(extract_strided_metadata) + <some math>
 //

@sabauma sabauma changed the title Unranked support for extract_aligned_pointer_as_index [mlir][memref] Unranked support for extract_aligned_pointer_as_index May 31, 2024
@sabauma
Copy link
Contributor Author

sabauma commented Jun 12, 2024

@jpienaar @matthias-springer @nicolasvasilache Just a gentle ping for reviewers.

@sabauma sabauma merged commit 846103c into llvm:main Jun 13, 2024
10 checks passed
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024
…lvm#93908)

memref.extract_aligned_pointer_as_index currently does not support
unranked inputs. This lack of support interferes with the folding
operations in the expand-strided-metadata pass.

    %r = memref.reinterpret_cast %arg0 to
        offset: [0],
        sizes: [],
        strides: [] : memref<*xf32> to memref<f32>

%i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked
tensors.

This change modifies the extract_aligned_pointer_as_index operation to
support unranked inputs with corresponding support in the MemRef->LLVM
conversion.

Co-authored-by: Spenser Bauman <sabauma@fastmail>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants