Skip to content

Commit 846103c

Browse files
sabaumaSpenser Bauman
andauthored
[mlir][memref] Unranked support for extract_aligned_pointer_as_index (#93908)
memref.extract_aligned_pointer_as_index currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass. %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32> %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index Patterns like this occur when bufferizing operations on unranked tensors. This change modifies the extract_aligned_pointer_as_index operation to support unranked inputs with corresponding support in the MemRef->LLVM conversion. Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent 24c6579 commit 846103c

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
892892
}];
893893

894894
let arguments = (ins
895-
AnyStridedMemRef:$source
895+
AnyRankedOrUnrankedMemRef:$source
896896
);
897897
let results = (outs Index:$aligned_pointer);
898898

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
15901590
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
15911591
OpAdaptor adaptor,
15921592
ConversionPatternRewriter &rewriter) const override {
1593-
MemRefDescriptor desc(adaptor.getSource());
1593+
BaseMemRefType sourceTy = extractOp.getSource().getType();
1594+
1595+
Value alignedPtr;
1596+
if (sourceTy.hasRank()) {
1597+
MemRefDescriptor desc(adaptor.getSource());
1598+
alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1599+
} else {
1600+
auto elementPtrTy = LLVM::LLVMPointerType::get(
1601+
rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1602+
1603+
UnrankedMemRefDescriptor desc(adaptor.getSource());
1604+
Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1605+
1606+
alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1607+
rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1608+
elementPtrTy);
1609+
}
1610+
15941611
rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1595-
extractOp, getTypeConverter()->getIndexType(),
1596-
desc.alignedPtr(rewriter, extractOp->getLoc()));
1612+
extractOp, getTypeConverter()->getIndexType(), alignedPtr);
15971613
return success();
15981614
}
15991615
};

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
598598

599599
// -----
600600

601+
// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
602+
func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
603+
%0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
604+
// CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
605+
// CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
606+
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
607+
// CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
608+
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
609+
610+
// CHECK: return %[[R]] : index
611+
return %0: index
612+
}
613+
614+
// -----
615+
601616
// CHECK-LABEL: func @extract_strided_metadata(
602617
// CHECK-SAME: %[[ARG:.*]]: memref
603618
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
899899

900900
// -----
901901

902+
// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
903+
// CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
904+
func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
905+
// CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
906+
// CHECK: return %[[I]]
907+
908+
%r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
909+
%i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
910+
return %i : index
911+
}
912+
913+
// -----
914+
902915
// Check that we simplify collapse_shape into
903916
// reinterpret_cast(extract_strided_metadata) + <some math>
904917
//

0 commit comments

Comments
 (0)