Skip to content

Commit 49b9531

Browse files
committed
emit inbounds and nuw attributes in memref.
Now that MLIR accepts nuw and nusw in getelementptr, this patch emits the inbounds and nuw attributes when lowering the memref load and store operators to LLVM. memref.load and memref.store are guaranteed to be inbounds: `0 <= idx < dim_size`. This patch also moves the |rewriter| parameter of getStridedElementPtr to the front so that LLVM::GEPNoWrapFlags can be added at the end with a default value and grouped together with other operators' parameters. Fixes: iree-org/iree#20483 Signed-off-by: Lin, Peiyong <[email protected]>
1 parent eb6d51a commit 49b9531

File tree

13 files changed

+83
-75
lines changed

13 files changed

+83
-75
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
8383

8484
// This is a strided getElementPtr variant that linearizes subscripts as:
8585
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
86-
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
87-
ValueRange indices,
88-
ConversionPatternRewriter &rewriter) const;
86+
Value getStridedElementPtr(
87+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
88+
Value memRefDesc, ValueRange indices,
89+
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
8990

9091
/// Returns if the given memref type is convertible to LLVM and has an
9192
/// identity layout map.

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11181118
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
11191119
return op.emitOpError("chipset unsupported element size");
11201120

1121-
Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
1122-
(adaptor.getSrcIndices()), rewriter);
1123-
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
1124-
(adaptor.getDstIndices()), rewriter);
1121+
Value srcPtr =
1122+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1123+
(adaptor.getSrcIndices()));
1124+
Value dstPtr =
1125+
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1126+
(adaptor.getDstIndices()));
11251127

11261128
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
11271129
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
299299
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
300300
loc, rewriter.getI64Type(), sliceIndex);
301301
return getStridedElementPtr(
302-
loc, llvm::cast<MemRefType>(tileMemory.getType()),
303-
descriptor.getResult(0), {sliceIndexI64, zero},
304-
static_cast<ConversionPatternRewriter &>(rewriter));
302+
static_cast<ConversionPatternRewriter &>(rewriter), loc,
303+
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
304+
{sliceIndexI64, zero});
305305
}
306306

307307
/// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
507507
if (!tileId)
508508
return failure();
509509

510-
Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
511-
adaptor.getBase(),
512-
adaptor.getIndices(), rewriter);
510+
Value ptr = this->getStridedElementPtr(
511+
rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
512+
adaptor.getIndices());
513513

514514
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515515

@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
554554

555555
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556556
Value ptr = this->getStridedElementPtr(
557-
loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558-
adaptor.getIndices(), rewriter);
557+
rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558+
adaptor.getIndices());
559559

560560
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561561

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
122122

123123
// Create nvvm.mma_load op according to the operand types.
124124
Value dataPtr = getStridedElementPtr(
125-
loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126-
adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
125+
rewriter, loc,
126+
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127+
adaptor.getSrcMemref(), adaptor.getIndices());
127128

128129
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
129130
loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
177178
}
178179

179180
Value dataPtr = getStridedElementPtr(
180-
loc,
181+
rewriter, loc,
181182
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182-
adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183+
adaptor.getDstMemref(), adaptor.getIndices());
183184
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
184185
loc, rewriter.getI32Type(),
185186
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
5959
}
6060

6161
Value ConvertToLLVMPattern::getStridedElementPtr(
62-
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63-
ConversionPatternRewriter &rewriter) const {
62+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63+
Value memRefDesc, ValueRange indices,
64+
LLVM::GEPNoWrapFlags noWrapFlags) const {
6465

6566
auto [strides, offset] = type.getStridesAndOffset();
6667

@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
9192
return index ? rewriter.create<LLVM::GEPOp>(
9293
loc, elementPtrType,
9394
getTypeConverter()->convertType(type.getElementType()),
94-
base, index)
95+
base, index, noWrapFlags)
9596
: base;
9697
}
9798

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -420,8 +420,8 @@ struct AssumeAlignmentOpLowering
420420
auto loc = op.getLoc();
421421

422422
auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
423-
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
424-
rewriter);
423+
Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
424+
/*indices=*/{});
425425

426426
// Emit llvm.assume(true) ["align"(memref, alignment)].
427427
// This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -644,8 +644,8 @@ struct GenericAtomicRMWOpLowering
644644
// Compute the loaded value and branch to the loop block.
645645
rewriter.setInsertionPointToEnd(initBlock);
646646
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
647-
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
648-
adaptor.getIndices(), rewriter);
647+
auto dataPtr = getStridedElementPtr(
648+
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
649649
Value init = rewriter.create<LLVM::LoadOp>(
650650
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
651651
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -829,9 +829,10 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
829829
ConversionPatternRewriter &rewriter) const override {
830830
auto type = loadOp.getMemRefType();
831831

832-
Value dataPtr =
833-
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
834-
adaptor.getIndices(), rewriter);
832+
Value dataPtr = getStridedElementPtr(
833+
rewriter, loadOp.getLoc(), type, adaptor.getMemref(),
834+
adaptor.getIndices(),
835+
LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
835836
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
836837
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
837838
false, loadOp.getNontemporal());
@@ -849,8 +850,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
849850
ConversionPatternRewriter &rewriter) const override {
850851
auto type = op.getMemRefType();
851852

852-
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
853-
adaptor.getIndices(), rewriter);
853+
Value dataPtr = getStridedElementPtr(
854+
rewriter, op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(),
855+
LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
854856
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
855857
0, false, op.getNontemporal());
856858
return success();
@@ -868,8 +870,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
868870
auto type = prefetchOp.getMemRefType();
869871
auto loc = prefetchOp.getLoc();
870872

871-
Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
872-
adaptor.getIndices(), rewriter);
873+
Value dataPtr = getStridedElementPtr(
874+
rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
873875

874876
// Replace with llvm.prefetch.
875877
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1809,8 +1811,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18091811
if (failed(memRefType.getStridesAndOffset(strides, offset)))
18101812
return failure();
18111813
auto dataPtr =
1812-
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1813-
adaptor.getIndices(), rewriter);
1814+
getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1815+
adaptor.getMemref(), adaptor.getIndices());
18141816
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
18151817
atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
18161818
LLVM::AtomicOrdering::acq_rel);

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283

284284
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285285
Value srcPtr =
286-
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287-
adaptor.getIndices(), rewriter);
286+
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
287+
adaptor.getSrcMemref(), adaptor.getIndices());
288288
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289289
ldMatrixResultType, srcPtr,
290290
/*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661661
Location loc = op.getLoc();
662662
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663663
Value dstPtr =
664-
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
665-
adaptor.getDstIndices(), rewriter);
664+
getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
665+
adaptor.getDst(), adaptor.getDstIndices());
666666
FailureOr<unsigned> dstAddressSpace =
667667
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668668
if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676676
return rewriter.notifyMatchFailure(
677677
loc, "source memref address space not convertible to integer");
678678

679-
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
680-
adaptor.getSrcIndices(), rewriter);
679+
Value scrPtr =
680+
getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
681+
adaptor.getSrcIndices());
681682
// Intrinsics takes a global pointer so we need an address space cast.
682683
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
683684
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814815
MemRefType mbarrierMemrefType =
815816
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
816817
return ConvertToLLVMPattern::getStridedElementPtr(
817-
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
818+
rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
818819
}
819820
};
820821

@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995996
ConversionPatternRewriter &rewriter) const override {
996997
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
997998
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
998-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
999-
adaptor.getDst(), {}, rewriter);
999+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1000+
adaptor.getDst(), {});
10001001
Value barrier =
10011002
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
10021003
adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
10211022
ConversionPatternRewriter &rewriter) const override {
10221023
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
10231024
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1024-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1025-
adaptor.getSrc(), {}, rewriter);
1025+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1026+
adaptor.getSrc(), {});
10261027
SmallVector<Value> coords = adaptor.getCoordinates();
10271028
for (auto [index, value] : llvm::enumerate(coords)) {
10281029
coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
10831084
Value leadDim = makeConst(leadDimVal);
10841085

10851086
Value baseAddr = getStridedElementPtr(
1086-
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087-
adaptor.getTensor(), {}, rewriter);
1087+
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1088+
adaptor.getTensor(), {});
10881089
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
10891090
// Just use 14 bits for base address
10901091
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
289289
// Resolve address.
290290
auto vtype = cast<VectorType>(
291291
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292-
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
293-
adaptor.getIndices(), rewriter);
292+
Value dataPtr = this->getStridedElementPtr(
293+
rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
294294
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
295295
rewriter);
296296
return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
337337
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
338338

339339
// Resolve address.
340-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
341-
adaptor.getIndices(), rewriter);
340+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
341+
adaptor.getBase(), adaptor.getIndices());
342342
Value base = adaptor.getBase();
343343
Value ptrs =
344344
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
393393
"could not resolve alignment");
394394

395395
// Resolve address.
396-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397-
adaptor.getIndices(), rewriter);
396+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
397+
adaptor.getBase(), adaptor.getIndices());
398398
Value ptrs =
399399
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
400400
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
428428

429429
// Resolve address.
430430
auto vtype = typeConverter->convertType(expand.getVectorType());
431-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432-
adaptor.getIndices(), rewriter);
431+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
432+
adaptor.getBase(), adaptor.getIndices());
433433

434434
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435435
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
450450
MemRefType memRefType = compress.getMemRefType();
451451

452452
// Resolve address.
453-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454-
adaptor.getIndices(), rewriter);
453+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
454+
adaptor.getBase(), adaptor.getIndices());
455455

456456
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457457
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
105105
if (failed(stride))
106106
return failure();
107107
// Replace operation with intrinsic.
108-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109-
adaptor.getIndices(), rewriter);
108+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
109+
adaptor.getBase(), adaptor.getIndices());
110110
Type resType = typeConverter->convertType(tType);
111111
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
112112
op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
131131
if (failed(stride))
132132
return failure();
133133
// Replace operation with intrinsic.
134-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135-
adaptor.getIndices(), rewriter);
134+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
135+
adaptor.getBase(), adaptor.getIndices());
136136
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
137137
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
138138
return success();

mlir/test/Conversion/FuncToLLVM/calling-convention.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 :
266266
// CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
267267

268268
// CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
269-
// CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
269+
// CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]]
270270
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
271271
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
272272

@@ -295,12 +295,12 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind
295295
// CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
296296

297297
// CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
298-
// CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
298+
// CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]]
299299
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
300300
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
301301

302302
// CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
303-
// CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]]
303+
// CHECK: %[[LOADPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR0]]
304304
// CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
305305
%0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
306306

0 commit comments

Comments
 (0)