@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
308
308
}
309
309
};
310
310
311
+ // / Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
312
+ class ExtractAlignedPointerAsIndexOpPattern
313
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
314
+ public:
315
+ using OpConversionPattern::OpConversionPattern;
316
+
317
+ LogicalResult
318
+ matchAndRewrite (memref::ExtractAlignedPointerAsIndexOp extractOp,
319
+ OpAdaptor adaptor,
320
+ ConversionPatternRewriter &rewriter) const override ;
321
+ };
311
322
} // namespace
312
323
313
324
// ===----------------------------------------------------------------------===//
@@ -922,17 +933,32 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
922
933
return success ();
923
934
}
924
935
936
+ // ===----------------------------------------------------------------------===//
937
+ // ExtractAlignedPointerAsIndexOp
938
+ // ===----------------------------------------------------------------------===//
939
+
940
+ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite (
941
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
942
+ ConversionPatternRewriter &rewriter) const {
943
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
944
+ Type indexType = typeConverter.getIndexType ();
945
+ rewriter.replaceOpWithNewOp <spirv::ConvertPtrToUOp>(extractOp, indexType,
946
+ adaptor.getSource ());
947
+ return success ();
948
+ }
949
+
925
950
// ===----------------------------------------------------------------------===//
926
951
// Pattern population
927
952
// ===----------------------------------------------------------------------===//
928
953
929
954
namespace mlir {
930
955
void populateMemRefToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
931
956
RewritePatternSet &patterns) {
932
- patterns.add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
933
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
934
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
935
- ReinterpretCastPattern, CastPattern>(typeConverter,
936
- patterns.getContext ());
957
+ patterns
958
+ .add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
959
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
960
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
961
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
962
+ typeConverter, patterns.getContext ());
937
963
}
938
964
} // namespace mlir
0 commit comments