Skip to content

[mlir][ptr] Add the ptradd and type_offset ops, and generic_space attr #136434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define PTR_ATTRDEFS

include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

// All of the attributes will extend this class.
class Ptr_Attr<string name, string attrMnemonic,
Expand All @@ -20,6 +22,31 @@ class Ptr_Attr<string name, string attrMnemonic,
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
// GenericSpaceAttr
//===----------------------------------------------------------------------===//

def Ptr_GenericSpaceAttr :
Ptr_Attr<"GenericSpace", "generic_space", [
DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
]> {
let summary = "Generic memory space";
let description = [{
The `generic_space` attribute defines a memory space attribute with the
following properties:
- Load and store operations are always valid, regardless of the type.
- Atomic operations are always valid, regardless of the type.
- Cast operations to `generic_space` are always valid.

Example:

```mlir
#ptr.generic_space
```
}];
let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// SpecAttr
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
#ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H
#define MLIR_DIALECT_PTR_IR_PTRATTRS_H

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "llvm/Support/TypeSize.h"

#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,15 @@ def AtomicOrdering : I64EnumAttr<
let cppNamespace = "::mlir::ptr";
}

//===----------------------------------------------------------------------===//
// Ptr add flags enum properties.
//===----------------------------------------------------------------------===//

def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [
I32EnumCase<"none", 0>, I32EnumCase<"nusw", 1>, I32EnumCase<"nuw", 2>,
I32EnumCase<"inbounds", 3>
]> {
let cppNamespace = "::mlir::ptr";
}

#endif // PTR_ENUMS
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"
Expand Down
74 changes: 74 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,81 @@

include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
include "mlir/Dialect/Ptr/IR/PtrEnums.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//

def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
]> {
let summary = "Pointer add operation";
let description = [{
The `ptr_add` operation adds an integer offset to a pointer to produce a new
pointer. The input and output pointer types are always the same.

Example:

```mlir
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32
```
}];

let arguments = (ins
Ptr_PtrType:$base,
AnySignlessIntegerOrIndex:$offset,
DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = [{
($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
}];
let hasFolder = 1;
let extraClassDeclaration = [{
/// `ViewLikeOp::getViewSource` method.
Value getViewSource() { return getBase(); }
}];
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//

def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
let summary = "Type offset operation";
let description = [{
The `type_offset` operation produces an int or index-typed SSA value
equal to a target-specific constant representing the offset of a single
element of the given type.

Example:

```mlir
// Return the offset between two f32 stored in memory
%0 = ptr.type_offset f32 : index
// Return the offset between two memref descriptors stored in memory
%1 = ptr.type_offset memref<12 x f64> : i32
```
}];

let arguments = (ins TypeAttr:$elementType);
let results = (outs AnySignlessIntegerOrIndex:$result);
let builders = [
OpBuilder<(ins "Type":$elementType)>
];
let assemblyFormat = [{
$elementType attr-dict `:` type($result)
}];
let extraClassDeclaration = [{
/// Returns the type offset according to `layout`. If `layout` is `nullopt`
/// the nearest layout the op will be used for the computation.
llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
}];
}

#endif // PTR_OPS
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ add_mlir_dialect_library(
MLIRIR
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
MLIRViewLikeInterface
)
46 changes: 46 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,59 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::ptr;

constexpr const static unsigned kBitsInByte = 8;

//===----------------------------------------------------------------------===//
// GenericSpaceAttr
//===----------------------------------------------------------------------===//

LogicalResult GenericSpaceAttr::isValidLoad(
Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
function_ref<InFlightDiagnostic()> emitError) const {
return success();
}

LogicalResult GenericSpaceAttr::isValidStore(
Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
function_ref<InFlightDiagnostic()> emitError) const {
return success();
}

LogicalResult GenericSpaceAttr::isValidAtomicOp(
ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
return success();
}

LogicalResult GenericSpaceAttr::isValidAtomicXchg(
Type type, ptr::AtomicOrdering successOrdering,
ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
function_ref<InFlightDiagnostic()> emitError) const {
return success();
}

LogicalResult GenericSpaceAttr::isValidAddrSpaceCast(
Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
// TODO: update this method once the `addrspace_cast` op is added to the
// dialect.
assert(false && "unimplemented, see TODO in the source.");
return failure();
}

LogicalResult GenericSpaceAttr::isValidPtrIntCast(
Type intLikeTy, Type ptrLikeTy,
function_ref<InFlightDiagnostic()> emitError) const {
// TODO: update this method once the int-cast ops are added to the dialect.
assert(false && "unimplemented, see TODO in the source.");
return failure();
}

//===----------------------------------------------------------------------===//
// SpecAttr
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

#include "mlir/Dialect/Ptr/IR/PtrOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
Expand All @@ -39,6 +41,31 @@ void PtrDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//

/// Fold: ptradd ptr + 0 -> ptr
OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
Attribute attr = adaptor.getOffset();
if (!attr)
return nullptr;
if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
return getBase();
return nullptr;
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//

llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
if (layout)
return layout->getTypeSize(getElementType());
DataLayout dl = DataLayout::closest(*this);
return dl.getTypeSize(getElementType());
}

//===----------------------------------------------------------------------===//
// Pointer API.
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Ptr/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s

/// Check `ptr_add` canonicalizer patterns.

// CHECK-LABEL: @zero_offset
// CHECK-SAME: (%[[PTR_0:.*]]: !ptr.ptr<#ptr.generic_space>)
func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
// CHECK-NOT: index.constant
// CHECK-NOT: ptr.ptr_add
// CHECK: return %[[PTR_0]] : !ptr.ptr<#ptr.generic_space>
// CHECK: }
%off = index.constant 0
%res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
return %res0 : !ptr.ptr<#ptr.generic_space>
}
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Ptr/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s

/// Check op assembly.
// CHECK-LABEL: @ptr_add_type_offset
func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
// CHECK: ptr.type_offset f32 : index
// CHECK-NEXT: ptr.ptr_add %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
// CHECK-NEXT: ptr.ptr_add %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
// CHECK-NEXT: ptr.ptr_add nusw %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
// CHECK-NEXT: ptr.ptr_add nuw %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
// CHECK-NEXT: ptr.ptr_add inbounds %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
%off = ptr.type_offset f32 : index
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
%res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
%res1 = ptr.ptr_add nusw %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
%res2 = ptr.ptr_add nuw %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
%res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
return %res : !ptr.ptr<#ptr.generic_space>
}
Loading