Skip to content

Commit f159774

Browse files
fabianmcgjoker-eph
andauthored
[mlir][core|ptr] Add PtrLikeTypeInterface and casting ops to the ptr dialect (#137469)
This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types. This interface is defined as: ``` A ptr-like type represents an object storing a memory address. This object is constituted by: - A memory address called the base pointer. This pointer is treated as a bag of bits without any assumed structure. The bit-width of the base pointer must be a compile-time constant. However, the bit-width may remain opaque or unavailable during transformations that do not depend on the base pointer. Finally, it is considered indivisible in the sense that as a `PtrLikeTypeInterface` value, it has no metadata. - Optional metadata about the pointer. For example, the size of the memory region associated with the pointer. Furthermore, all ptr-like types have two properties: - The memory space associated with the address held by the pointer. - An optional element type. If the element type is not specified, the pointer is considered opaque. ``` This patch adds this interface to `!ptr.ptr` and the `memref` type. Furthermore, this patch adds necessary ops and type to handle casting between `!ptr.ptr` and ptr-like types. First, it defines the `!ptr.ptr_metadata` type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like `memref` cannot be specified, as its structure is tied to its lowering. The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered. Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations. Allowing to cast back and forth between `!ptr.ptr` and ptr-like types. ```mlir func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> { %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space> %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> return %res : memref<f32, #ptr.generic_space> } ``` It's future work to replace and remove the `bare-ptr-convention` through the use of these ops. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 925dbc7 commit f159774

File tree

11 files changed

+458
-3
lines changed

11 files changed

+458
-3
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
3737

3838
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
3939
MemRefElementTypeInterface,
40+
PtrLikeTypeInterface,
4041
VectorElementTypeInterface,
4142
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
4243
"areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,55 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
6364
return $_get(memorySpace.getContext(), memorySpace);
6465
}]>
6566
];
67+
let extraClassDeclaration = [{
68+
// `PtrLikeTypeInterface` interface methods.
69+
/// Returns `Type()` as this pointer type is opaque.
70+
Type getElementType() const {
71+
return Type();
72+
}
73+
/// Clones the pointer with specified memory space or returns failure
74+
/// if an `elementType` was specified or if the memory space doesn't
75+
/// implement `MemorySpaceAttrInterface`.
76+
FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
77+
std::optional<Type> elementType) const {
78+
if (elementType)
79+
return failure();
80+
if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
81+
return cast<PtrLikeTypeInterface>(get(ms));
82+
return failure();
83+
}
84+
/// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
85+
bool hasPtrMetadata() const {
86+
return false;
87+
}
88+
}];
89+
}
90+
91+
def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
92+
let summary = "Pointer metadata type";
93+
let description = [{
94+
The `ptr_metadata` type represents an opaque-view of the metadata associated
95+
with a `ptr-like` object type.
96+
97+
Note: It's a verification error to construct a `ptr_metadata` type using a
98+
`ptr-like` type with no metadata.
99+
100+
Example:
101+
102+
```mlir
103+
// The metadata associated with a `memref` type.
104+
!ptr.ptr_metadata<memref<f32>>
105+
```
106+
}];
107+
let parameters = (ins "PtrLikeTypeInterface":$type);
108+
let assemblyFormat = "`<` $type `>`";
109+
let builders = [
110+
TypeBuilderWithInferredContext<(ins
111+
"PtrLikeTypeInterface":$ptrLike), [{
112+
return $_get(ptrLike.getContext(), ptrLike);
113+
}]>
114+
];
115+
let genVerifyDecl = 1;
66116
}
67117

68118
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,72 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/Interfaces/ViewLikeInterface.td"
1818
include "mlir/IR/OpAsmInterface.td"
1919

20+
//===----------------------------------------------------------------------===//
21+
// FromPtrOp
22+
//===----------------------------------------------------------------------===//
23+
24+
def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
25+
Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
26+
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
27+
]> {
28+
let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
29+
let description = [{
30+
The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
31+
important to note that:
32+
- The ptr-like object cannot be a `!ptr.ptr`.
33+
- The memory-space of both the `ptr` and ptr-like object must match.
34+
- The cast is Pure (no UB and side-effect free).
35+
36+
The optional `metadata` operand exists to provide any ptr-like metadata
37+
that might be required to perform the cast.
38+
39+
Example:
40+
41+
```mlir
42+
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
43+
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
44+
45+
// Cast the `%ptr` to a memref without utilizing metadata.
46+
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
47+
```
48+
}];
49+
50+
let arguments = (ins Ptr_PtrType:$ptr, Optional<Ptr_PtrMetadata>:$metadata);
51+
let results = (outs PtrLikeTypeInterface:$result);
52+
let assemblyFormat = [{
53+
$ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)
54+
}];
55+
let hasFolder = 1;
56+
let hasVerifier = 1;
57+
}
58+
59+
//===----------------------------------------------------------------------===//
60+
// GetMetadataOp
61+
//===----------------------------------------------------------------------===//
62+
63+
def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
64+
Pure, TypesMatchWith<"metadata type", "ptr", "result",
65+
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
66+
]> {
67+
let summary = "SSA value representing pointer metadata.";
68+
let description = [{
69+
The `get_metadata` operation produces an opaque value that encodes the
70+
metadata of the ptr-like type.
71+
72+
Example:
73+
74+
```mlir
75+
%metadata = ptr.get_metadata %memref : memref<?x?xf32>
76+
```
77+
}];
78+
79+
let arguments = (ins PtrLikeTypeInterface:$ptr);
80+
let results = (outs Ptr_PtrMetadata:$result);
81+
let assemblyFormat = [{
82+
$ptr attr-dict `:` type($ptr)
83+
}];
84+
}
85+
2086
//===----------------------------------------------------------------------===//
2187
// PtrAddOp
2288
//===----------------------------------------------------------------------===//
@@ -32,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
3298
Example:
3399

34100
```mlir
35-
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32
36-
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32
101+
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
102+
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
37103
```
38104
}];
39105

@@ -52,6 +118,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
52118
}];
53119
}
54120

121+
//===----------------------------------------------------------------------===//
122+
// ToPtrOp
123+
//===----------------------------------------------------------------------===//
124+
125+
def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
126+
let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
127+
let description = [{
128+
The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
129+
important to note that:
130+
- The ptr-like object cannot be a `!ptr.ptr`.
131+
- The memory-space of both the `ptr` and ptr-like object must match.
132+
- The cast is side-effect free.
133+
134+
Example:
135+
136+
```mlir
137+
%ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
138+
%ptr1 = ptr.to_ptr %memref : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
139+
```
140+
}];
141+
142+
let arguments = (ins PtrLikeTypeInterface:$ptr);
143+
let results = (outs Ptr_PtrType:$result);
144+
let assemblyFormat = [{
145+
$ptr attr-dict `:` type($ptr) `->` type($result)
146+
}];
147+
let hasFolder = 1;
148+
let hasVerifier = 1;
149+
}
150+
55151
//===----------------------------------------------------------------------===//
56152
// TypeOffsetOp
57153
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,59 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
110110
}];
111111
}
112112

113+
//===----------------------------------------------------------------------===//
114+
// PtrLikeTypeInterface
115+
//===----------------------------------------------------------------------===//
116+
117+
def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
118+
let cppNamespace = "::mlir";
119+
let description = [{
120+
A ptr-like type represents an object storing a memory address. This object
121+
is constituted by:
122+
- A memory address called the base pointer. This pointer is treated as a
123+
bag of bits without any assumed structure. The bit-width of the base
124+
pointer must be a compile-time constant. However, the bit-width may remain
125+
opaque or unavailable during transformations that do not depend on the
126+
base pointer. Finally, it is considered indivisible in the sense that as
127+
a `PtrLikeTypeInterface` value, it has no metadata.
128+
- Optional metadata about the pointer. For example, the size of the memory
129+
region associated with the pointer.
130+
131+
Furthermore, all ptr-like types have two properties:
132+
- The memory space associated with the address held by the pointer.
133+
- An optional element type. If the element type is not specified, the
134+
pointer is considered opaque.
135+
}];
136+
let methods = [
137+
InterfaceMethod<[{
138+
Returns the memory space of this ptr-like type.
139+
}],
140+
"::mlir::Attribute", "getMemorySpace">,
141+
InterfaceMethod<[{
142+
Returns the element type of this ptr-like type. Note: this method can
143+
return `::mlir::Type()`, in which case the pointer is considered opaque.
144+
}],
145+
"::mlir::Type", "getElementType">,
146+
InterfaceMethod<[{
147+
Returns whether this ptr-like type has non-empty metadata.
148+
}],
149+
"bool", "hasPtrMetadata">,
150+
InterfaceMethod<[{
151+
Returns a clone of this type with the given memory space and element type,
152+
or `failure` if the type cannot be cloned with the specified arguments.
153+
If the pointer is opaque and `elementType` is not `std::nullopt` the
154+
method will return `failure`.
155+
156+
If no `elementType` is provided and ptr is not opaque, the `elementType`
157+
of this type is used.
158+
}],
159+
"::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
160+
"::mlir::Attribute":$memorySpace,
161+
"::std::optional<::mlir::Type>":$elementType
162+
)>
163+
];
164+
}
165+
113166
//===----------------------------------------------------------------------===//
114167
// ShapedType
115168
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
9999
/// Note: This class attaches the ShapedType trait to act as a mixin to
100100
/// provide many useful utility functions. This inheritance has no effect
101101
/// on derived memref types.
102-
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
102+
class BaseMemRefType : public Type,
103+
public PtrLikeTypeInterface::Trait<BaseMemRefType>,
104+
public ShapedType::Trait<BaseMemRefType> {
103105
public:
104106
using Type::Type;
105107

@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
117119
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
118120
Type elementType) const;
119121

122+
/// Clone this type with the given memory space and element type. If the
123+
/// provided element type is `std::nullopt`, the current element type of the
124+
/// type is used.
125+
FailureOr<PtrLikeTypeInterface>
126+
clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
127+
120128
// Make sure that base class overloads are visible.
121129
using ShapedType::Trait<BaseMemRefType>::clone;
122130

@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
141149
/// New `Attribute getMemorySpace()` method should be used instead.
142150
unsigned getMemorySpaceAsInt() const;
143151

152+
/// Returns that this ptr-like object has non-empty ptr metadata.
153+
bool hasPtrMetadata() const { return true; }
154+
144155
/// Allow implicit conversion to ShapedType.
145156
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
157+
158+
/// Allow implicit conversion to PtrLikeTypeInterface.
159+
operator PtrLikeTypeInterface() const {
160+
return llvm::cast<PtrLikeTypeInterface>(*this);
161+
}
146162
};
147163

148164
} // namespace mlir

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
562562
//===----------------------------------------------------------------------===//
563563

564564
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
565+
PtrLikeTypeInterface,
565566
ShapedTypeInterface
566567
], "BaseMemRefType"> {
567568
let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
11431144
//===----------------------------------------------------------------------===//
11441145

11451146
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
1147+
PtrLikeTypeInterface,
11461148
ShapedTypeInterface
11471149
], "BaseMemRefType"> {
11481150
let summary = "Shaped reference, with unknown rank, to a region of memory";

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,52 @@ void PtrDialect::initialize() {
4141
>();
4242
}
4343

44+
//===----------------------------------------------------------------------===//
45+
// FromPtrOp
46+
//===----------------------------------------------------------------------===//
47+
48+
OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
49+
// Fold the pattern:
50+
// %ptr = ptr.to_ptr %v : type -> ptr
51+
// (%mda = ptr.get_metadata %v : type)?
52+
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
53+
// To:
54+
// %val -> %v
55+
Value ptrLike;
56+
FromPtrOp fromPtr = *this;
57+
while (fromPtr != nullptr) {
58+
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
59+
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
60+
// different.
61+
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
62+
return ptrLike;
63+
Value md = fromPtr.getMetadata();
64+
// If the type has trivial metadata fold.
65+
if (!fromPtr.getType().hasPtrMetadata()) {
66+
ptrLike = toPtr.getPtr();
67+
} else if (md) {
68+
// Fold if the metadata can be verified to be equal.
69+
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
70+
mdOp && mdOp.getPtr() == toPtr.getPtr())
71+
ptrLike = toPtr.getPtr();
72+
}
73+
// Check for a sequence of casts.
74+
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
75+
: nullptr);
76+
}
77+
return ptrLike;
78+
}
79+
80+
LogicalResult FromPtrOp::verify() {
81+
if (isa<PtrType>(getType()))
82+
return emitError() << "the result type cannot be `!ptr.ptr`";
83+
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
84+
return emitError()
85+
<< "expected the input and output to have the same memory space";
86+
}
87+
return success();
88+
}
89+
4490
//===----------------------------------------------------------------------===//
4591
// PtrAddOp
4692
//===----------------------------------------------------------------------===//
@@ -55,6 +101,40 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
55101
return nullptr;
56102
}
57103

104+
//===----------------------------------------------------------------------===//
105+
// ToPtrOp
106+
//===----------------------------------------------------------------------===//
107+
108+
OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
109+
// Fold the pattern:
110+
// %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
111+
// %ptr = ptr.to_ptr %val : type -> ptr
112+
// To:
113+
// %ptr -> %p
114+
Value ptr;
115+
ToPtrOp toPtr = *this;
116+
while (toPtr != nullptr) {
117+
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
118+
// Cannot fold if it's not a `from_ptr` op.
119+
if (!fromPtr)
120+
return ptr;
121+
ptr = fromPtr.getPtr();
122+
// Check for chains of casts.
123+
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
124+
}
125+
return ptr;
126+
}
127+
128+
LogicalResult ToPtrOp::verify() {
129+
if (isa<PtrType>(getPtr().getType()))
130+
return emitError() << "the input value cannot be of type `!ptr.ptr`";
131+
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
132+
return emitError()
133+
<< "expected the input and output to have the same memory space";
134+
}
135+
return success();
136+
}
137+
58138
//===----------------------------------------------------------------------===//
59139
// TypeOffsetOp
60140
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)