Skip to content

[mlir][Ptr] Add the MemorySpaceAttrInterface interface and dependencies. #86870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@ set(LLVM_TARGET_DEFINITIONS PtrOps.td)
mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr)
mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr)
add_public_tablegen_target(MLIRPtrOpsAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td)
mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)

set(LLVM_TARGET_DEFINITIONS PtrOps.td)
mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRPtrOpsEnumsGen)
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===-- MemorySpaceInterfaces.h - ptr memory space interfaces ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the ptr dialect memory space interfaces.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H
#define MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir {
class Operation;
namespace ptr {
enum class AtomicBinOp : uint64_t;
enum class AtomicOrdering : uint64_t;
} // namespace ptr
} // namespace mlir

#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc"

#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc"

#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H
117 changes: 117 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===-- MemorySpaceInterfaces.td - Memory space interfaces ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines memory space attribute interfaces.
//
//===----------------------------------------------------------------------===//

#ifndef PTR_MEMORYSPACEINTERFACES
#define PTR_MEMORYSPACEINTERFACES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Memory space attribute interface.
//===----------------------------------------------------------------------===//

def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
let description = [{
This interface defines a common API for interacting with the memory model of
a memory space and the operations in the pointer dialect.

Furthermore, this interface allows concepts such as read-only memory to be
adequately modeled and enforced.
}];
let cppNamespace = "::mlir::ptr";
let methods = [
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to load a value from the memory space
with a specific type, alignment, and atomic ordering.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianmcg : nitpicking here but methods that starts with is are predicates and should return bool: LogicalResult is a about method that can fail instead. Can you update?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can update. But these methods can emit an error, remember they get a ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError parameter.

Admittedly, I would prefer these methods to return something like a SuccesOrDiagnostic, so users can have more control on how to handle the result. Or to decouple the diagnostic from the validity method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FailureOr<bool>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good immediate option, I'll change to that. But I think we are missing something similar to llvm::Expected, like:

template <typename T>
class DiagnosticOr : public std::optional<T> {
   operator LogicalResult () const;
   Diagnostic takeDiagnostic();
   void emitDiagnostic();
   T& getValue();
private:
  Diagnostic message;
};

Which would allow users to decide how to handle the diagnostic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're trying to be more efficient by avoiding formatting a diagnostic when the user does not want it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On success we shouldn't format a diagnostic, but I'm still not convinced about the approach.
I'll change to FailureOr<bool>, I'll revisit if I have a better idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On success we shouldn't format a diagnostic

I'm talking about the failing case where the user of the API would just discard the diagnostic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just remembered the emitError is always optional and the user of the API can opt to pass a default constructed object, in which case the method cannot produce errors. Therefore, it already handles the case of discarding the diagnostic in a efficient manner.

I'll switch to FailureOr<bool>.

/*methodName=*/ "isValidLoad",
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
"::mlir::IntegerAttr":$alignment,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to store a value in the memory space
with a specific type, alignment, and atomic ordering.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidStore",
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
"::mlir::IntegerAttr":$alignment,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to perform an atomic operation in the
memory space with a specific type, alignment, and atomic ordering.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidAtomicOp",
/*args=*/ (ins "::mlir::ptr::AtomicBinOp":$op,
"::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
"::mlir::IntegerAttr":$alignment,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to perform an atomic exchange operation
in the memory space with a specific type, alignment, and atomic
orderings.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidAtomicXchg",
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$successOrdering,
"::mlir::ptr::AtomicOrdering":$failureOrdering,
"::mlir::IntegerAttr":$alignment,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to perform an `addrspacecast` op
in the memory space.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidAddrSpaceCast",
/*args=*/ (ins "::mlir::Type":$tgt,
"::mlir::Type":$src,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
/*desc=*/ [{
This method checks if it's valid to perform a `ptrtoint` or `inttoptr`
op in the memory space.
The first type is expected to be integer-like, while the second must be a
ptr-like type.
If `emitError` is non-null then the method is allowed to emit errors.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "isValidPtrIntCast",
/*args=*/ (ins "::mlir::Type":$intLikeTy,
"::mlir::Type":$ptrLikeTy,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
];
}

#endif // PTR_MEMORYSPACEINTERFACES
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"

#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"

#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
let description = [{
The `ptr` type is an opaque pointer type. This type typically represents a
handle to an object in memory or target-dependent values like `nullptr`.
Pointers are optionally parameterized by a memory space.
Pointers are parameterized by a memory space.

Syntax:

Expand All @@ -54,14 +54,14 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
memory-space ::= attribute-value
```
}];
let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
let assemblyFormat = "(`<` $memorySpace^ `>`)?";
let parameters = (ins "MemorySpaceAttrInterface":$memorySpace);
let assemblyFormat = "`<` $memorySpace `>`";
let builders = [
TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{
return $_get($_ctxt, memorySpace);
TypeBuilderWithInferredContext<(ins
"MemorySpaceAttrInterface":$memorySpace), [{
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PTR_ENUMS
#define PTR_ENUMS

include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Atomic binary op enum attribute.
//===----------------------------------------------------------------------===//

def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">;
def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">;
def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">;
def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">;
def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">;
def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">;
def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">;
def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">;
def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">;
def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">;
def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">;
def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">;
def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">;
def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">;
def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">;
def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">;
def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">;

def AtomicBinOp : I64EnumAttr<
"AtomicBinOp",
"ptr.atomicrmw binary operations",
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap,
AtomicBinOpUDecWrap]> {
let cppNamespace = "::mlir::ptr";
}

//===----------------------------------------------------------------------===//
// Atomic ordering enum attribute.
//===----------------------------------------------------------------------===//

def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">;
def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">;
def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">;
def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">;
def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">;
def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">;
def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">;

def AtomicOrdering : I64EnumAttr<
"AtomicOrdering",
"Atomic ordering for LLVM's memory model",
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel,
AtomicOrderingSeqCst
]> {
let cppNamespace = "::mlir::ptr";
}

#endif // PTR_ENUMS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

#endif // PTR_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H

#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ add_mlir_dialect_library(
DEPENDS
MLIRPtrOpsAttributesIncGen
MLIRPtrOpsIncGen

MLIRPtrOpsEnumsGen
MLIRPtrMemorySpaceInterfacesIncGen
LINK_LIBS
PUBLIC
MLIRIR
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ void PtrDialect::initialize() {
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"

#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"

#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"

#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"

Expand Down
31 changes: 18 additions & 13 deletions mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ constexpr const static unsigned kDefaultPointerAlignmentBits = 8;
/// Searches the data layout for the pointer spec, returns nullptr if it is not
/// found.
static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type,
Attribute defaultMemorySpace) {
MemorySpaceAttrInterface defaultMemorySpace) {
for (DataLayoutEntryInterface entry : params) {
if (!entry.isTypeEntry())
continue;
Expand All @@ -38,9 +38,11 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type,
return spec;
}
}
// If not found, and this is the pointer to the default memory space, assume
// 64-bit pointers.
if (type.getMemorySpace() == defaultMemorySpace)
// If not found, and this is the pointer to the default memory space or if
// `defaultMemorySpace` is null, assume 64-bit pointers. `defaultMemorySpace`
// might be null if the data layout doesn't define the default memory space.
if (type.getMemorySpace() == defaultMemorySpace ||
defaultMemorySpace == nullptr)
return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits,
kDefaultPointerAlignmentBits,
kDefaultPointerAlignmentBits, kDefaultPointerSizeBits);
Expand Down Expand Up @@ -93,44 +95,47 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,

uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace();
auto defaultMemorySpace = llvm::cast_if_present<MemorySpaceAttrInterface>(
dataLayout.getDefaultMemorySpace());
if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace))
return spec.getAbi() / kBitsInByte;

return dataLayout.getTypeABIAlignment(get(getContext(), defaultMemorySpace));
return dataLayout.getTypeABIAlignment(get(defaultMemorySpace));
}

std::optional<uint64_t>
PtrType::getIndexBitwidth(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace();
auto defaultMemorySpace = llvm::cast_if_present<MemorySpaceAttrInterface>(
dataLayout.getDefaultMemorySpace());
if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) {
return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize()
: spec.getIndex();
}

return dataLayout.getTypeIndexBitwidth(get(getContext(), defaultMemorySpace));
return dataLayout.getTypeIndexBitwidth(get(defaultMemorySpace));
}

llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace();
auto defaultMemorySpace = llvm::cast_if_present<MemorySpaceAttrInterface>(
dataLayout.getDefaultMemorySpace());
if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace))
return llvm::TypeSize::getFixed(spec.getSize());

// For other memory spaces, use the size of the pointer to the default memory
// space.
return dataLayout.getTypeSizeInBits(get(getContext(), defaultMemorySpace));
return dataLayout.getTypeSizeInBits(get(defaultMemorySpace));
}

uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
Attribute defaultMemorySpace = dataLayout.getDefaultMemorySpace();
auto defaultMemorySpace = llvm::cast_if_present<MemorySpaceAttrInterface>(
dataLayout.getDefaultMemorySpace());
if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace))
return spec.getPreferred() / kBitsInByte;

return dataLayout.getTypePreferredAlignment(
get(getContext(), defaultMemorySpace));
return dataLayout.getTypePreferredAlignment(get(defaultMemorySpace));
}

LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
Expand Down
Loading