Skip to content

emit inbounds and nuw attributes in memref. #138984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern {

// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
ValueRange indices,
ConversionPatternRewriter &rewriter) const;
Value getStridedElementPtr(
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;

/// Returns if the given memref type is convertible to LLVM and has an
/// identity layout map.
Expand Down
14 changes: 12 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,12 @@ def LoadOp : MemRef_Op<"load",
The `load` op reads an element from a memref at the specified indices.

The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
be in-bounds: `0 <= idx < dim_size`.

Lowerings of `memref.load` may emit attributes, e.g. `inbounds` + `nuw`
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
behavior if indices are out of bounds or if computing the offset in the
memref would cause signed overflow of the `index` type.

The single result of `memref.load` is a value with the same type as the
element type of the memref.
Expand Down Expand Up @@ -1881,7 +1886,12 @@ def MemRef_StoreOp : MemRef_Op<"store",
The `store` op stores an element into a memref at the specified indices.

The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
be in-bounds: `0 <= idx < dim_size`.

Lowerings of `memref.store` may emit attributes, e.g. `inbounds` + `nuw`
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
behavior if indices are out of bounds or if computing the offset in the
memref would cause signed overflow of the `index` type.

A set `nontemporal` attribute indicates that this store is not expected to
be reused in the cache. For details, refer to the
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("chipset unsupported element size");

Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()), rewriter);
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
(adaptor.getDstIndices()), rewriter);
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()));
Value dstPtr =
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
(adaptor.getDstIndices()));

rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
loc, llvm::cast<MemRefType>(tileMemory.getType()),
descriptor.getResult(0), {sliceIndexI64, zero},
static_cast<ConversionPatternRewriter &>(rewriter));
static_cast<ConversionPatternRewriter &>(rewriter), loc,
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
{sliceIndexI64, zero});
}

/// Emits an in-place swap of a slice of a tile in ZA and a slice of a
Expand Down Expand Up @@ -507,9 +507,9 @@ struct LoadTileSliceConversion
if (!tileId)
return failure();

Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = this->getStridedElementPtr(
rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
adaptor.getIndices());

auto tileSlice = loadTileSliceOp.getTileSliceIndex();

Expand Down Expand Up @@ -554,8 +554,8 @@ struct StoreTileSliceConversion

// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
Value ptr = this->getStridedElementPtr(
loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
adaptor.getIndices(), rewriter);
rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
adaptor.getIndices());

auto tileSlice = storeTileSliceOp.getTileSliceIndex();

Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering

// Create nvvm.mma_load op according to the operand types.
Value dataPtr = getStridedElementPtr(
loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
rewriter, loc,
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices());

Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
Expand Down Expand Up @@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
}

Value dataPtr = getStridedElementPtr(
loc,
rewriter, loc,
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
adaptor.getDstMemref(), adaptor.getIndices());
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
}

Value ConvertToLLVMPattern::getStridedElementPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags) const {

auto [strides, offset] = type.getStridesAndOffset();

Expand Down Expand Up @@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
return index ? rewriter.create<LLVM::GEPOp>(
loc, elementPtrType,
getTypeConverter()->convertType(type.getElementType()),
base, index)
base, index, noWrapFlags)
: base;
}

Expand Down
36 changes: 23 additions & 13 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ namespace mlir {

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
Expand Down Expand Up @@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
auto loc = op.getLoc();

auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
/*indices=*/{});

// Emit llvm.assume(true) ["align"(memref, alignment)].
// This is more direct than ptrtoint-based checks, is explicitly supported,
Expand Down Expand Up @@ -644,8 +647,8 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
auto dataPtr = getStridedElementPtr(
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
Value init = rewriter.create<LLVM::LoadOp>(
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
Expand Down Expand Up @@ -829,9 +832,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
auto type = loadOp.getMemRefType();

Value dataPtr =
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
// Per memref.load spec, the indices must be in-bounds:
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
adaptor.getMemref(),
adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
false, loadOp.getNontemporal());
Expand All @@ -849,8 +855,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();

Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
// Per memref.store spec, the indices must be in-bounds:
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
Value dataPtr =
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
0, false, op.getNontemporal());
return success();
Expand All @@ -868,8 +878,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
auto type = prefetchOp.getMemRefType();
auto loc = prefetchOp.getLoc();

Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
Value dataPtr = getStridedElementPtr(
rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

// Replace with llvm.prefetch.
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
Expand Down Expand Up @@ -1809,8 +1819,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return failure();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
adaptor.getMemref(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
LLVM::AtomicOrdering::acq_rel);
Expand Down
27 changes: 14 additions & 13 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {

auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
Expand Down Expand Up @@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
Location loc = op.getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
Value dstPtr =
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
adaptor.getDst(), adaptor.getDstIndices());
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
if (failed(dstAddressSpace))
Expand All @@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
return rewriter.notifyMatchFailure(
loc, "source memref address space not convertible to integer");

Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
adaptor.getSrcIndices(), rewriter);
Value scrPtr =
getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
adaptor.getSrcIndices());
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
Expand Down Expand Up @@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
MemRefType mbarrierMemrefType =
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
return ConvertToLLVMPattern::getStridedElementPtr(
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
}
};

Expand Down Expand Up @@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
adaptor.getDst(), {}, rewriter);
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
adaptor.getDst(), {});
Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Expand All @@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
adaptor.getSrc(), {}, rewriter);
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
adaptor.getSrc(), {});
SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) {
coords[index] = truncToI32(b, value);
Expand Down Expand Up @@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
Value leadDim = makeConst(leadDimVal);

Value baseAddr = getStridedElementPtr(
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {}, rewriter);
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {});
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve address.
auto vtype = cast<VectorType>(
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value dataPtr = this->getStridedElementPtr(
rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
rewriter);
return success();
Expand Down Expand Up @@ -337,8 +337,8 @@ class VectorGatherOpConversion
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
Expand Down Expand Up @@ -393,8 +393,8 @@ class VectorScatterOpConversion
"could not resolve alignment");

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
Expand Down Expand Up @@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion

// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());

rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
Expand All @@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
MemRefType memRefType = compress.getMemRefType();

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());

rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
if (failed(stride))
return failure();
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
adaptor.getBase(), adaptor.getIndices());
Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride.value());
Expand All @@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
if (failed(stride))
return failure();
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
return success();
Expand Down
Loading