@@ -307,6 +307,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
307
307
}
308
308
};
309
309
310
+ // / Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
311
+ class ExtractAlignedPointerAsIndexOpPattern final
312
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
313
+ public:
314
+ using OpConversionPattern::OpConversionPattern;
315
+
316
+ LogicalResult
317
+ matchAndRewrite (memref::ExtractAlignedPointerAsIndexOp extractOp,
318
+ OpAdaptor adaptor,
319
+ ConversionPatternRewriter &rewriter) const override ;
320
+ };
310
321
} // namespace
311
322
312
323
// ===----------------------------------------------------------------------===//
@@ -921,17 +932,32 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
921
932
return success ();
922
933
}
923
934
935
+ // ===----------------------------------------------------------------------===//
936
+ // ExtractAlignedPointerAsIndexOp
937
+ // ===----------------------------------------------------------------------===//
938
+
939
+ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite (
940
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
941
+ ConversionPatternRewriter &rewriter) const {
942
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
943
+ Type indexType = typeConverter.getIndexType ();
944
+ rewriter.replaceOpWithNewOp <spirv::ConvertPtrToUOp>(extractOp, indexType,
945
+ adaptor.getSource ());
946
+ return success ();
947
+ }
948
+
924
949
// ===----------------------------------------------------------------------===//
925
950
// Pattern population
926
951
// ===----------------------------------------------------------------------===//
927
952
928
953
namespace mlir {
929
954
void populateMemRefToSPIRVPatterns (const SPIRVTypeConverter &typeConverter,
930
955
RewritePatternSet &patterns) {
931
- patterns.add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
932
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
933
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
934
- ReinterpretCastPattern, CastPattern>(typeConverter,
935
- patterns.getContext ());
956
+ patterns
957
+ .add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
958
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961
+ typeConverter, patterns.getContext ());
936
962
}
937
963
} // namespace mlir
0 commit comments