Skip to content

Commit 04ad8d4

Browse files
authored
Emit inbounds and nuw attributes in memref. (#138984)
Now that MLIR accepts nuw and nusw in getelementptr, this patch emits the inbounds and nuw attributes when lower memref to LLVM in load and store operators. This patch also strengthens the memref.load and memref.store spec about undefined behaviour during lowering. This patch also lifts the |rewriter| parameter in getStridedElementPtr ahead so that LLVM::GEPNoWrapFlags can be added at the end with a default value and grouped together with other operators' parameters. Signed-off-by: Lin, Peiyong <[email protected]>
1 parent 11db128 commit 04ad8d4

File tree

14 files changed

+103
-77
lines changed

14 files changed

+103
-77
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
109109

110110
// This is a strided getElementPtr variant that linearizes subscripts as:
111111
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
112-
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
113-
ValueRange indices,
114-
ConversionPatternRewriter &rewriter) const;
112+
Value getStridedElementPtr(
113+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
114+
Value memRefDesc, ValueRange indices,
115+
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
115116

116117
/// Returns if the given memref type is convertible to LLVM and has an
117118
/// identity layout map.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,12 @@ def LoadOp : MemRef_Op<"load",
12021202
The `load` op reads an element from a memref at the specified indices.
12031203

12041204
The number of indices must match the rank of the memref. The indices must
1205-
be in-bounds: `0 <= idx < dim_size`
1205+
be in-bounds: `0 <= idx < dim_size`.
1206+
1207+
Lowerings of `memref.load` may emit attributes, e.g. `inbouds` + `nuw`
1208+
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
1209+
behavior if indices are out of bounds or if computing the offset in the
1210+
memref would cause signed overflow of the `index` type.
12061211

12071212
The single result of `memref.load` is a value with the same type as the
12081213
element type of the memref.
@@ -1896,7 +1901,12 @@ def MemRef_StoreOp : MemRef_Op<"store",
18961901
The `store` op stores an element into a memref at the specified indices.
18971902

18981903
The number of indices must match the rank of the memref. The indices must
1899-
be in-bounds: `0 <= idx < dim_size`
1904+
be in-bounds: `0 <= idx < dim_size`.
1905+
1906+
Lowerings of `memref.store` may emit attributes, e.g. `inbouds` + `nuw`
1907+
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
1908+
behavior if indices are out of bounds or if computing the offset in the
1909+
memref would cause signed overflow of the `index` type.
19001910

19011911
A set `nontemporal` attribute indicates that this store is not expected to
19021912
be reused in the cache. For details, refer to the

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,10 +1117,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11171117
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
11181118
return op.emitOpError("chipset unsupported element size");
11191119

1120-
Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
1121-
(adaptor.getSrcIndices()), rewriter);
1122-
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
1123-
(adaptor.getDstIndices()), rewriter);
1120+
Value srcPtr =
1121+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1122+
(adaptor.getSrcIndices()));
1123+
Value dstPtr =
1124+
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1125+
(adaptor.getDstIndices()));
11241126

11251127
rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
11261128
op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
299299
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
300300
loc, rewriter.getI64Type(), sliceIndex);
301301
return getStridedElementPtr(
302-
loc, llvm::cast<MemRefType>(tileMemory.getType()),
303-
descriptor.getResult(0), {sliceIndexI64, zero},
304-
static_cast<ConversionPatternRewriter &>(rewriter));
302+
static_cast<ConversionPatternRewriter &>(rewriter), loc,
303+
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
304+
{sliceIndexI64, zero});
305305
}
306306

307307
/// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
507507
if (!tileId)
508508
return failure();
509509

510-
Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
511-
adaptor.getBase(),
512-
adaptor.getIndices(), rewriter);
510+
Value ptr = this->getStridedElementPtr(
511+
rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
512+
adaptor.getIndices());
513513

514514
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515515

@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
554554

555555
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556556
Value ptr = this->getStridedElementPtr(
557-
loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558-
adaptor.getIndices(), rewriter);
557+
rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558+
adaptor.getIndices());
559559

560560
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561561

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
122122

123123
// Create nvvm.mma_load op according to the operand types.
124124
Value dataPtr = getStridedElementPtr(
125-
loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126-
adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
125+
rewriter, loc,
126+
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127+
adaptor.getSrcMemref(), adaptor.getIndices());
127128

128129
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
129130
loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
177178
}
178179

179180
Value dataPtr = getStridedElementPtr(
180-
loc,
181+
rewriter, loc,
181182
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182-
adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183+
adaptor.getDstMemref(), adaptor.getIndices());
183184
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
184185
loc, rewriter.getI32Type(),
185186
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
5959
}
6060

6161
Value ConvertToLLVMPattern::getStridedElementPtr(
62-
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63-
ConversionPatternRewriter &rewriter) const {
62+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63+
Value memRefDesc, ValueRange indices,
64+
LLVM::GEPNoWrapFlags noWrapFlags) const {
6465

6566
auto [strides, offset] = type.getStridesAndOffset();
6667

@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
9192
return index ? rewriter.create<LLVM::GEPOp>(
9293
loc, elementPtrType,
9394
getTypeConverter()->convertType(type.getElementType()),
94-
base, index)
95+
base, index, noWrapFlags)
9596
: base;
9697
}
9798

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ namespace mlir {
3535

3636
using namespace mlir;
3737

38+
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
39+
LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
40+
3841
namespace {
3942

4043
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
420423
auto loc = op.getLoc();
421424

422425
auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
423-
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
424-
rewriter);
426+
Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
427+
/*indices=*/{});
425428

426429
// Emit llvm.assume(true) ["align"(memref, alignment)].
427430
// This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -643,8 +646,8 @@ struct GenericAtomicRMWOpLowering
643646
// Compute the loaded value and branch to the loop block.
644647
rewriter.setInsertionPointToEnd(initBlock);
645648
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
646-
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
647-
adaptor.getIndices(), rewriter);
649+
auto dataPtr = getStridedElementPtr(
650+
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
648651
Value init = rewriter.create<LLVM::LoadOp>(
649652
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
650653
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -828,9 +831,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
828831
ConversionPatternRewriter &rewriter) const override {
829832
auto type = loadOp.getMemRefType();
830833

831-
Value dataPtr =
832-
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
833-
adaptor.getIndices(), rewriter);
834+
// Per memref.load spec, the indices must be in-bounds:
835+
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
836+
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
837+
Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
838+
adaptor.getMemref(),
839+
adaptor.getIndices(), kNoWrapFlags);
834840
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
835841
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
836842
false, loadOp.getNontemporal());
@@ -848,8 +854,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
848854
ConversionPatternRewriter &rewriter) const override {
849855
auto type = op.getMemRefType();
850856

851-
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
852-
adaptor.getIndices(), rewriter);
857+
// Per memref.store spec, the indices must be in-bounds:
858+
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
859+
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
860+
Value dataPtr =
861+
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
862+
adaptor.getIndices(), kNoWrapFlags);
853863
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
854864
0, false, op.getNontemporal());
855865
return success();
@@ -867,8 +877,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
867877
auto type = prefetchOp.getMemRefType();
868878
auto loc = prefetchOp.getLoc();
869879

870-
Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
871-
adaptor.getIndices(), rewriter);
880+
Value dataPtr = getStridedElementPtr(
881+
rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
872882

873883
// Replace with llvm.prefetch.
874884
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1808,8 +1818,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18081818
if (failed(memRefType.getStridesAndOffset(strides, offset)))
18091819
return failure();
18101820
auto dataPtr =
1811-
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1812-
adaptor.getIndices(), rewriter);
1821+
getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1822+
adaptor.getMemref(), adaptor.getIndices());
18131823
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
18141824
atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
18151825
LLVM::AtomicOrdering::acq_rel);

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283

284284
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285285
Value srcPtr =
286-
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287-
adaptor.getIndices(), rewriter);
286+
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
287+
adaptor.getSrcMemref(), adaptor.getIndices());
288288
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289289
ldMatrixResultType, srcPtr,
290290
/*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661661
Location loc = op.getLoc();
662662
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663663
Value dstPtr =
664-
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
665-
adaptor.getDstIndices(), rewriter);
664+
getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
665+
adaptor.getDst(), adaptor.getDstIndices());
666666
FailureOr<unsigned> dstAddressSpace =
667667
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668668
if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676676
return rewriter.notifyMatchFailure(
677677
loc, "source memref address space not convertible to integer");
678678

679-
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
680-
adaptor.getSrcIndices(), rewriter);
679+
Value scrPtr =
680+
getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
681+
adaptor.getSrcIndices());
681682
// Intrinsics takes a global pointer so we need an address space cast.
682683
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
683684
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814815
MemRefType mbarrierMemrefType =
815816
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
816817
return ConvertToLLVMPattern::getStridedElementPtr(
817-
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
818+
rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
818819
}
819820
};
820821

@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995996
ConversionPatternRewriter &rewriter) const override {
996997
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
997998
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
998-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
999-
adaptor.getDst(), {}, rewriter);
999+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1000+
adaptor.getDst(), {});
10001001
Value barrier =
10011002
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
10021003
adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
10211022
ConversionPatternRewriter &rewriter) const override {
10221023
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
10231024
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1024-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1025-
adaptor.getSrc(), {}, rewriter);
1025+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1026+
adaptor.getSrc(), {});
10261027
SmallVector<Value> coords = adaptor.getCoordinates();
10271028
for (auto [index, value] : llvm::enumerate(coords)) {
10281029
coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
10831084
Value leadDim = makeConst(leadDimVal);
10841085

10851086
Value baseAddr = getStridedElementPtr(
1086-
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087-
adaptor.getTensor(), {}, rewriter);
1087+
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1088+
adaptor.getTensor(), {});
10881089
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
10891090
// Just use 14 bits for base address
10901091
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
289289
// Resolve address.
290290
auto vtype = cast<VectorType>(
291291
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292-
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
293-
adaptor.getIndices(), rewriter);
292+
Value dataPtr = this->getStridedElementPtr(
293+
rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
294294
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
295295
rewriter);
296296
return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
337337
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
338338

339339
// Resolve address.
340-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
341-
adaptor.getIndices(), rewriter);
340+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
341+
adaptor.getBase(), adaptor.getIndices());
342342
Value base = adaptor.getBase();
343343
Value ptrs =
344344
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
393393
"could not resolve alignment");
394394

395395
// Resolve address.
396-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397-
adaptor.getIndices(), rewriter);
396+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
397+
adaptor.getBase(), adaptor.getIndices());
398398
Value ptrs =
399399
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
400400
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
428428

429429
// Resolve address.
430430
auto vtype = typeConverter->convertType(expand.getVectorType());
431-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432-
adaptor.getIndices(), rewriter);
431+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
432+
adaptor.getBase(), adaptor.getIndices());
433433

434434
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435435
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
450450
MemRefType memRefType = compress.getMemRefType();
451451

452452
// Resolve address.
453-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454-
adaptor.getIndices(), rewriter);
453+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
454+
adaptor.getBase(), adaptor.getIndices());
455455

456456
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457457
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
105105
if (failed(stride))
106106
return failure();
107107
// Replace operation with intrinsic.
108-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109-
adaptor.getIndices(), rewriter);
108+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
109+
adaptor.getBase(), adaptor.getIndices());
110110
Type resType = typeConverter->convertType(tType);
111111
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
112112
op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
131131
if (failed(stride))
132132
return failure();
133133
// Replace operation with intrinsic.
134-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135-
adaptor.getIndices(), rewriter);
134+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
135+
adaptor.getBase(), adaptor.getIndices());
136136
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
137137
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
138138
return success();

0 commit comments

Comments
 (0)