Skip to content

Commit c96a224

Browse files
committed
emit inbounds and nuw attributes in memref.
Now that MLIR accepts nuw and nusw in getelementptr, this patch emits the inbounds and nuw attributes when lower memref to LLVM in load and store operators. This patch also strengthens the memref.load and memref.store spec about undefined behaviour during lowering. This patch also lifts the |rewriter| parameter in getStridedElementPtr ahead so that LLVM::GEPNoWrapFlags can be added at the end with a default value and grouped together with other operators' parameters. fixes: iree-org/iree#20483 Signed-off-by: Lin, Peiyong <[email protected]>
1 parent 1043810 commit c96a224

File tree

14 files changed

+103
-77
lines changed

14 files changed

+103
-77
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
8383

8484
// This is a strided getElementPtr variant that linearizes subscripts as:
8585
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
86-
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
87-
ValueRange indices,
88-
ConversionPatternRewriter &rewriter) const;
86+
Value getStridedElementPtr(
87+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
88+
Value memRefDesc, ValueRange indices,
89+
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
8990

9091
/// Returns if the given memref type is convertible to LLVM and has an
9192
/// identity layout map.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,12 @@ def LoadOp : MemRef_Op<"load",
11871187
The `load` op reads an element from a memref at the specified indices.
11881188

11891189
The number of indices must match the rank of the memref. The indices must
1190-
be in-bounds: `0 <= idx < dim_size`
1190+
be in-bounds: `0 <= idx < dim_size`.
1191+
1192+
Lowerings of `memref.load` may emit attributes, e.g. `inbouds` + `nuw`
1193+
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
1194+
behavior if indices are out of bounds or if computing the offset in the
1195+
memref would cause signed overflow of the `index` type.
11911196

11921197
The single result of `memref.load` is a value with the same type as the
11931198
element type of the memref.
@@ -1881,7 +1886,12 @@ def MemRef_StoreOp : MemRef_Op<"store",
18811886
The `store` op stores an element into a memref at the specified indices.
18821887

18831888
The number of indices must match the rank of the memref. The indices must
1884-
be in-bounds: `0 <= idx < dim_size`
1889+
be in-bounds: `0 <= idx < dim_size`.
1890+
1891+
Lowerings of `memref.store` may emit attributes, e.g. `inbouds` + `nuw`
1892+
when converting to LLVM's `llvm.getelementptr`, that would cause undefined
1893+
behavior if indices are out of bounds or if computing the offset in the
1894+
memref would cause signed overflow of the `index` type.
18851895

18861896
A set `nontemporal` attribute indicates that this store is not expected to
18871897
be reused in the cache. For details, refer to the

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11181118
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
11191119
return op.emitOpError("chipset unsupported element size");
11201120

1121-
Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
1122-
(adaptor.getSrcIndices()), rewriter);
1123-
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
1124-
(adaptor.getDstIndices()), rewriter);
1121+
Value srcPtr =
1122+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1123+
(adaptor.getSrcIndices()));
1124+
Value dstPtr =
1125+
getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1126+
(adaptor.getDstIndices()));
11251127

11261128
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
11271129
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
299299
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
300300
loc, rewriter.getI64Type(), sliceIndex);
301301
return getStridedElementPtr(
302-
loc, llvm::cast<MemRefType>(tileMemory.getType()),
303-
descriptor.getResult(0), {sliceIndexI64, zero},
304-
static_cast<ConversionPatternRewriter &>(rewriter));
302+
static_cast<ConversionPatternRewriter &>(rewriter), loc,
303+
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
304+
{sliceIndexI64, zero});
305305
}
306306

307307
/// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
507507
if (!tileId)
508508
return failure();
509509

510-
Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
511-
adaptor.getBase(),
512-
adaptor.getIndices(), rewriter);
510+
Value ptr = this->getStridedElementPtr(
511+
rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
512+
adaptor.getIndices());
513513

514514
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515515

@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
554554

555555
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556556
Value ptr = this->getStridedElementPtr(
557-
loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558-
adaptor.getIndices(), rewriter);
557+
rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558+
adaptor.getIndices());
559559

560560
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561561

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
122122

123123
// Create nvvm.mma_load op according to the operand types.
124124
Value dataPtr = getStridedElementPtr(
125-
loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126-
adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
125+
rewriter, loc,
126+
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127+
adaptor.getSrcMemref(), adaptor.getIndices());
127128

128129
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
129130
loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
177178
}
178179

179180
Value dataPtr = getStridedElementPtr(
180-
loc,
181+
rewriter, loc,
181182
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182-
adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183+
adaptor.getDstMemref(), adaptor.getIndices());
183184
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
184185
loc, rewriter.getI32Type(),
185186
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
5959
}
6060

6161
Value ConvertToLLVMPattern::getStridedElementPtr(
62-
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63-
ConversionPatternRewriter &rewriter) const {
62+
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63+
Value memRefDesc, ValueRange indices,
64+
LLVM::GEPNoWrapFlags noWrapFlags) const {
6465

6566
auto [strides, offset] = type.getStridesAndOffset();
6667

@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
9192
return index ? rewriter.create<LLVM::GEPOp>(
9293
loc, elementPtrType,
9394
getTypeConverter()->convertType(type.getElementType()),
94-
base, index)
95+
base, index, noWrapFlags)
9596
: base;
9697
}
9798

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ namespace mlir {
3535

3636
using namespace mlir;
3737

38+
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
39+
LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
40+
3841
namespace {
3942

4043
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
420423
auto loc = op.getLoc();
421424

422425
auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
423-
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
424-
rewriter);
426+
Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
427+
/*indices=*/{});
425428

426429
// Emit llvm.assume(true) ["align"(memref, alignment)].
427430
// This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -644,8 +647,8 @@ struct GenericAtomicRMWOpLowering
644647
// Compute the loaded value and branch to the loop block.
645648
rewriter.setInsertionPointToEnd(initBlock);
646649
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
647-
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
648-
adaptor.getIndices(), rewriter);
650+
auto dataPtr = getStridedElementPtr(
651+
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
649652
Value init = rewriter.create<LLVM::LoadOp>(
650653
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
651654
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -829,9 +832,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
829832
ConversionPatternRewriter &rewriter) const override {
830833
auto type = loadOp.getMemRefType();
831834

832-
Value dataPtr =
833-
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
834-
adaptor.getIndices(), rewriter);
835+
// Per memref.load spec, the indices must be in-bounds:
836+
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
837+
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
838+
Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
839+
adaptor.getMemref(),
840+
adaptor.getIndices(), kNoWrapFlags);
835841
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
836842
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
837843
false, loadOp.getNontemporal());
@@ -849,8 +855,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
849855
ConversionPatternRewriter &rewriter) const override {
850856
auto type = op.getMemRefType();
851857

852-
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
853-
adaptor.getIndices(), rewriter);
858+
// Per memref.store spec, the indices must be in-bounds:
859+
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
860+
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
861+
Value dataPtr =
862+
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
863+
adaptor.getIndices(), kNoWrapFlags);
854864
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
855865
0, false, op.getNontemporal());
856866
return success();
@@ -868,8 +878,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
868878
auto type = prefetchOp.getMemRefType();
869879
auto loc = prefetchOp.getLoc();
870880

871-
Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
872-
adaptor.getIndices(), rewriter);
881+
Value dataPtr = getStridedElementPtr(
882+
rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
873883

874884
// Replace with llvm.prefetch.
875885
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1809,8 +1819,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18091819
if (failed(memRefType.getStridesAndOffset(strides, offset)))
18101820
return failure();
18111821
auto dataPtr =
1812-
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1813-
adaptor.getIndices(), rewriter);
1822+
getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1823+
adaptor.getMemref(), adaptor.getIndices());
18141824
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
18151825
atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
18161826
LLVM::AtomicOrdering::acq_rel);

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283

284284
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285285
Value srcPtr =
286-
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287-
adaptor.getIndices(), rewriter);
286+
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
287+
adaptor.getSrcMemref(), adaptor.getIndices());
288288
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289289
ldMatrixResultType, srcPtr,
290290
/*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661661
Location loc = op.getLoc();
662662
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663663
Value dstPtr =
664-
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
665-
adaptor.getDstIndices(), rewriter);
664+
getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
665+
adaptor.getDst(), adaptor.getDstIndices());
666666
FailureOr<unsigned> dstAddressSpace =
667667
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668668
if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676676
return rewriter.notifyMatchFailure(
677677
loc, "source memref address space not convertible to integer");
678678

679-
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
680-
adaptor.getSrcIndices(), rewriter);
679+
Value scrPtr =
680+
getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
681+
adaptor.getSrcIndices());
681682
// Intrinsics takes a global pointer so we need an address space cast.
682683
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
683684
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814815
MemRefType mbarrierMemrefType =
815816
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
816817
return ConvertToLLVMPattern::getStridedElementPtr(
817-
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
818+
rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
818819
}
819820
};
820821

@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995996
ConversionPatternRewriter &rewriter) const override {
996997
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
997998
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
998-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
999-
adaptor.getDst(), {}, rewriter);
999+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1000+
adaptor.getDst(), {});
10001001
Value barrier =
10011002
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
10021003
adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
10211022
ConversionPatternRewriter &rewriter) const override {
10221023
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
10231024
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1024-
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1025-
adaptor.getSrc(), {}, rewriter);
1025+
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1026+
adaptor.getSrc(), {});
10261027
SmallVector<Value> coords = adaptor.getCoordinates();
10271028
for (auto [index, value] : llvm::enumerate(coords)) {
10281029
coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
10831084
Value leadDim = makeConst(leadDimVal);
10841085

10851086
Value baseAddr = getStridedElementPtr(
1086-
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087-
adaptor.getTensor(), {}, rewriter);
1087+
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1088+
adaptor.getTensor(), {});
10881089
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
10891090
// Just use 14 bits for base address
10901091
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
289289
// Resolve address.
290290
auto vtype = cast<VectorType>(
291291
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292-
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
293-
adaptor.getIndices(), rewriter);
292+
Value dataPtr = this->getStridedElementPtr(
293+
rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
294294
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
295295
rewriter);
296296
return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
337337
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
338338

339339
// Resolve address.
340-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
341-
adaptor.getIndices(), rewriter);
340+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
341+
adaptor.getBase(), adaptor.getIndices());
342342
Value base = adaptor.getBase();
343343
Value ptrs =
344344
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
393393
"could not resolve alignment");
394394

395395
// Resolve address.
396-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397-
adaptor.getIndices(), rewriter);
396+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
397+
adaptor.getBase(), adaptor.getIndices());
398398
Value ptrs =
399399
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
400400
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
428428

429429
// Resolve address.
430430
auto vtype = typeConverter->convertType(expand.getVectorType());
431-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432-
adaptor.getIndices(), rewriter);
431+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
432+
adaptor.getBase(), adaptor.getIndices());
433433

434434
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435435
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
450450
MemRefType memRefType = compress.getMemRefType();
451451

452452
// Resolve address.
453-
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454-
adaptor.getIndices(), rewriter);
453+
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
454+
adaptor.getBase(), adaptor.getIndices());
455455

456456
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457457
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
105105
if (failed(stride))
106106
return failure();
107107
// Replace operation with intrinsic.
108-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109-
adaptor.getIndices(), rewriter);
108+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
109+
adaptor.getBase(), adaptor.getIndices());
110110
Type resType = typeConverter->convertType(tType);
111111
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
112112
op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
131131
if (failed(stride))
132132
return failure();
133133
// Replace operation with intrinsic.
134-
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135-
adaptor.getIndices(), rewriter);
134+
Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
135+
adaptor.getBase(), adaptor.getIndices());
136136
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
137137
op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
138138
return success();

0 commit comments

Comments
 (0)