-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][memref] Unranked support for extract_aligned_pointer_as_index #93908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass. %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32> %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index Patterns like this occur when bufferizing operations on unranked tensors. This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Spenser Bauman (sabauma) Changesmemref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass.
Patterns like this occur when bufferizing operations on unranked tensors. This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion. Full diff: https://github.com/llvm/llvm-project/pull/93908.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..df40e7a17a15f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
}];
let arguments = (ins
- AnyStridedMemRef:$source
+ AnyRankedOrUnrankedMemRef:$source
);
let results = (outs Index:$aligned_pointer);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2dc42f0a85e66..82c4b04656b33 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefDescriptor desc(adaptor.getSource());
+ BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+ Value alignedPtr;
+ if (sourceTy.hasRank()) {
+ MemRefDescriptor desc(adaptor.getSource());
+ alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+ } else {
+ auto elementPtrTy = LLVM::LLVMPointerType::get(
+ rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+ UnrankedMemRefDescriptor desc(adaptor.getSource());
+ Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+ alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+ rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+ elementPtrTy);
+ }
+
rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
- extractOp, getTypeConverter()->getIndexType(),
- desc.alignedPtr(rewriter, extractOp->getLoc()));
+ extractOp, getTypeConverter()->getIndexType(), alignedPtr);
return success();
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index baf9cfe610a5a..882804132e66d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
// -----
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
+func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
+ // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
+ // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+ // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
+ // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R]] : index
+ return %0: index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_strided_metadata(
// CHECK-SAME: %[[ARG:.*]]: memref
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 3bd6b7c1fd791..d884ade319532 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
// -----
+// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
+// CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
+func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
+ // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
+ // CHECK: return %[[I]]
+
+ %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
+ %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
+ return %i : index
+}
+
+// -----
+
// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
//
|
@jpienaar @matthias-springer @nicolasvasilache Just a gentle ping for reviewers. |
…lvm#93908) memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass. %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32> %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index Patterns like this occur when bufferizing operations on unranked tensors. This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion. Co-authored-by: Spenser Bauman <sabauma@fastmail>
memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass.
Patterns like this occur when bufferizing operations on unranked tensors.
This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion.